diff --git a/bin/pyspark b/bin/pyspark
index 6655725ef8e8e..96f30a260a09e 100755
--- a/bin/pyspark
+++ b/bin/pyspark
@@ -50,22 +50,47 @@ fi
. "$FWDIR"/bin/load-spark-env.sh
-# Figure out which Python executable to use
+# In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython`
+# executable, while the worker would still be launched using PYSPARK_PYTHON.
+#
+# In Spark 1.2, we removed the documentation of the IPYTHON and IPYTHON_OPTS variables and added
+# PYSPARK_DRIVER_PYTHON and PYSPARK_DRIVER_PYTHON_OPTS to allow IPython to be used for the driver.
+# Now, users can simply set PYSPARK_DRIVER_PYTHON=ipython to use IPython and set
+# PYSPARK_DRIVER_PYTHON_OPTS to pass options when starting the Python driver
+# (e.g. PYSPARK_DRIVER_PYTHON_OPTS='notebook'). This supports full customization of the IPython
+# and executor Python executables.
+#
+# For backwards-compatibility, we retain the old IPYTHON and IPYTHON_OPTS variables.
+
+# Determine the Python executable to use if PYSPARK_PYTHON or PYSPARK_DRIVER_PYTHON isn't set:
+if hash python2.7 2>/dev/null; then
+ # Attempt to use Python 2.7, if installed:
+ DEFAULT_PYTHON="python2.7"
+else
+ DEFAULT_PYTHON="python"
+fi
+
+# Determine the Python executable to use for the driver:
+if [[ -n "$IPYTHON_OPTS" || "$IPYTHON" == "1" ]]; then
+ # If IPython options are specified, assume user wants to run IPython
+ # (for backwards-compatibility)
+ PYSPARK_DRIVER_PYTHON_OPTS="$PYSPARK_DRIVER_PYTHON_OPTS $IPYTHON_OPTS"
+ PYSPARK_DRIVER_PYTHON="ipython"
+elif [[ -z "$PYSPARK_DRIVER_PYTHON" ]]; then
+ PYSPARK_DRIVER_PYTHON="${PYSPARK_PYTHON:-"$DEFAULT_PYTHON"}"
+fi
+
+# Determine the Python executable to use for the executors:
if [[ -z "$PYSPARK_PYTHON" ]]; then
- if [[ "$IPYTHON" = "1" || -n "$IPYTHON_OPTS" ]]; then
- # for backward compatibility
- PYSPARK_PYTHON="ipython"
+ if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && $DEFAULT_PYTHON != "python2.7" ]]; then
+ echo "IPython requires Python 2.7+; please install python2.7 or set PYSPARK_PYTHON" 1>&2
+ exit 1
else
- PYSPARK_PYTHON="python"
+ PYSPARK_PYTHON="$DEFAULT_PYTHON"
fi
fi
export PYSPARK_PYTHON
-if [[ -z "$PYSPARK_PYTHON_OPTS" && -n "$IPYTHON_OPTS" ]]; then
- # for backward compatibility
- PYSPARK_PYTHON_OPTS="$IPYTHON_OPTS"
-fi
-
# Add the PySpark classes to the Python path:
export PYTHONPATH="$SPARK_HOME/python/:$PYTHONPATH"
export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH"
@@ -93,9 +118,9 @@ if [[ -n "$SPARK_TESTING" ]]; then
unset YARN_CONF_DIR
unset HADOOP_CONF_DIR
if [[ -n "$PYSPARK_DOC_TEST" ]]; then
- exec "$PYSPARK_PYTHON" -m doctest $1
+ exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1
else
- exec "$PYSPARK_PYTHON" $1
+ exec "$PYSPARK_DRIVER_PYTHON" $1
fi
exit
fi
@@ -111,5 +136,5 @@ if [[ "$1" =~ \.py$ ]]; then
else
# PySpark shell requires special handling downstream
export PYSPARK_SHELL=1
- exec "$PYSPARK_PYTHON" $PYSPARK_PYTHON_OPTS
+ exec "$PYSPARK_DRIVER_PYTHON" $PYSPARK_DRIVER_PYTHON_OPTS
fi
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 396cdd1247e07..b709b8880ba76 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -21,6 +21,7 @@ import scala.language.implicitConversions
import java.io._
import java.net.URI
+import java.util.Arrays
import java.util.concurrent.atomic.AtomicInteger
import java.util.{Properties, UUID}
import java.util.UUID.randomUUID
@@ -1429,7 +1430,10 @@ object SparkContext extends Logging {
simpleWritableConverter[Boolean, BooleanWritable](_.get)
implicit def bytesWritableConverter(): WritableConverter[Array[Byte]] = {
- simpleWritableConverter[Array[Byte], BytesWritable](_.getBytes)
+ simpleWritableConverter[Array[Byte], BytesWritable](bw =>
+ // getBytes returns an array that may be longer than the actual data; copy only the first getLength bytes
+ Arrays.copyOfRange(bw.getBytes, 0, bw.getLength)
+ )
}
implicit def stringWritableConverter(): WritableConverter[String] =
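Context for the change above: `BytesWritable.getBytes` returns the writable's internal backing array, whose capacity can exceed the logical length reported by `getLength`, so converting it directly picks up stale trailing bytes. A minimal standalone sketch of the pitfall and the `copyOfRange` fix (illustrative only, not part of the patch):

```scala
import java.util.Arrays
import org.apache.hadoop.io.BytesWritable

object BytesWritablePitfall {
  def main(args: Array[String]): Unit = {
    val bw = new BytesWritable()
    bw.set((1 to 10).map(_.toByte).toArray, 0, 10)
    bw.set((1 to 5).map(_.toByte).toArray, 0, 5)

    // The backing array keeps its old (larger) capacity after the second set():
    println(bw.getBytes.length)  // typically > 5
    println(bw.getLength)        // 5, the logical size

    // Copying only the first getLength bytes gives the intended result:
    val data = Arrays.copyOfRange(bw.getBytes, 0, bw.getLength)
    println(data.length)         // 5
  }
}
```

The new SparkContextSuite test further down exercises exactly this buffer-reuse behaviour.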
diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index 8ca731038e528..e72826dc25f41 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -26,6 +26,8 @@ import scala.collection.JavaConversions._
import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}
import com.google.common.io.Files
+import org.apache.spark.util.Utils
+
/**
* Utilities for tests. Included in main codebase since it's used by multiple
* projects.
@@ -42,8 +44,7 @@ private[spark] object TestUtils {
* in order to avoid interference between tests.
*/
def createJarWithClasses(classNames: Seq[String], value: String = ""): URL = {
- val tempDir = Files.createTempDir()
- tempDir.deleteOnExit()
+ val tempDir = Utils.createTempDir()
val files = for (name <- classNames) yield createCompiledClass(name, tempDir, value)
val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis()))
createJar(files, jarFile)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index c74f86548ef85..4acbdf9d5e25f 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -25,8 +25,6 @@ import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collectio
import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.language.existentials
-import scala.reflect.ClassTag
-import scala.util.{Try, Success, Failure}
import net.razorvine.pickle.{Pickler, Unpickler}
@@ -42,7 +40,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
private[spark] class PythonRDD(
- parent: RDD[_],
+ @transient parent: RDD[_],
command: Array[Byte],
envVars: JMap[String, String],
pythonIncludes: JList[String],
@@ -55,9 +53,9 @@ private[spark] class PythonRDD(
val bufferSize = conf.getInt("spark.buffer.size", 65536)
val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)
- override def getPartitions = parent.partitions
+ override def getPartitions = firstParent.partitions
- override val partitioner = if (preservePartitoning) parent.partitioner else None
+ override val partitioner = if (preservePartitoning) firstParent.partitioner else None
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val startTime = System.currentTimeMillis
@@ -234,7 +232,7 @@ private[spark] class PythonRDD(
dataOut.writeInt(command.length)
dataOut.write(command)
// Data values
- PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
+ PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
dataOut.flush()
} catch {
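A quick sketch of the pattern adopted here (illustrative, not the real PythonRDD): the constructor reference is marked `@transient` and the partition/compute logic goes through `firstParent`, which is recovered from the RDD's dependencies, presumably so the parent RDD is not dragged into the serialized task closure.

```scala
import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD

// Hypothetical pass-through RDD using the same @transient + firstParent pattern.
class PassThroughRDD[T: ClassTag](@transient parent: RDD[T]) extends RDD[T](parent) {

  override def getPartitions: Array[Partition] = firstParent[T].partitions

  override def compute(split: Partition, context: TaskContext): Iterator[T] =
    firstParent[T].iterator(split, context)
}
```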
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 71bdf0fe1b917..e314408c067e9 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -108,10 +108,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
// Create and start the worker
- val pb = new ProcessBuilder(Seq(pythonExec, "-u", "-m", "pyspark.worker"))
+ val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.worker"))
val workerEnv = pb.environment()
workerEnv.putAll(envVars)
workerEnv.put("PYTHONPATH", pythonPath)
+ // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
+ workerEnv.put("PYTHONUNBUFFERED", "YES")
val worker = pb.start()
// Redirect worker stdout and stderr
@@ -149,10 +151,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
try {
// Create and start the daemon
- val pb = new ProcessBuilder(Seq(pythonExec, "-u", "-m", "pyspark.daemon"))
+ val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.daemon"))
val workerEnv = pb.environment()
workerEnv.putAll(envVars)
workerEnv.put("PYTHONPATH", pythonPath)
+ // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
+ workerEnv.put("PYTHONUNBUFFERED", "YES")
daemon = pb.start()
val in = new DataInputStream(daemon.getInputStream)
diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
index 79b4d7ea41a33..af94b05ce3847 100644
--- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
@@ -34,7 +34,8 @@ object PythonRunner {
val pythonFile = args(0)
val pyFiles = args(1)
val otherArgs = args.slice(2, args.length)
- val pythonExec = sys.env.get("PYSPARK_PYTHON").getOrElse("python") // TODO: get this from conf
+ val pythonExec =
+ sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", sys.env.getOrElse("PYSPARK_PYTHON", "python"))
// Format python file paths before adding them to the PYTHONPATH
val formattedPythonFile = formatPath(pythonFile)
@@ -57,6 +58,7 @@ object PythonRunner {
val builder = new ProcessBuilder(Seq(pythonExec, formattedPythonFile) ++ otherArgs)
val env = builder.environment()
env.put("PYTHONPATH", pythonPath)
+ // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
index a4409181ec907..4c9ca97a2a6b7 100644
--- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
@@ -66,13 +66,27 @@ sealed abstract class ManagedBuffer {
final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long)
extends ManagedBuffer {
+ /**
+ * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889).
+ * Avoid unless there's a good reason not to.
+ */
+ private val MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024;
+
override def size: Long = length
override def nioByteBuffer(): ByteBuffer = {
var channel: FileChannel = null
try {
channel = new RandomAccessFile(file, "r").getChannel
- channel.map(MapMode.READ_ONLY, offset, length)
+ // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead.
+ if (length < MIN_MEMORY_MAP_BYTES) {
+ val buf = ByteBuffer.allocate(length.toInt)
+ channel.read(buf, offset)
+ buf.flip()
+ buf
+ } else {
+ channel.map(MapMode.READ_ONLY, offset, length)
+ }
} catch {
case e: IOException =>
Try(channel.size).toOption match {
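For reference, a self-contained sketch (names are illustrative, not the ManagedBuffer API) of the same copy-vs-mmap decision: small segments are read into a regular heap buffer, which has to be flipped before it can be consumed, while large segments still use a read-only memory mapping.

```scala
import java.io.{File, RandomAccessFile}
import java.nio.ByteBuffer
import java.nio.channels.FileChannel.MapMode

object FileSegmentRead {
  // Below this size, copying is cheaper than paying the mmap setup/teardown cost.
  private val MinMemoryMapBytes = 2L * 1024 * 1024

  def read(file: File, offset: Long, length: Long): ByteBuffer = {
    val channel = new RandomAccessFile(file, "r").getChannel
    try {
      if (length < MinMemoryMapBytes) {
        val buf = ByteBuffer.allocate(length.toInt)
        channel.position(offset)
        // Loop to guard against short reads; a single read() may not fill the buffer.
        while (buf.hasRemaining && channel.read(buf) != -1) {}
        buf.flip() // switch the buffer from "being written" to "being read"
        buf
      } else {
        channel.map(MapMode.READ_ONLY, offset, length)
      }
    } finally {
      channel.close()
    }
  }
}
```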
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
index 6b00190c5eccc..9396b6ba84e7e 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
@@ -748,9 +748,7 @@ private[nio] class ConnectionManager(
} catch {
case e: Exception => {
logError(s"Exception was thrown while processing message", e)
- val m = Message.createBufferMessage(bufferMessage.id)
- m.hasError = true
- ackMessage = Some(m)
+ ackMessage = Some(Message.createErrorMessage(e, bufferMessage.id))
}
} finally {
sendMessage(connectionManagerId, ackMessage.getOrElse {
@@ -913,8 +911,12 @@ private[nio] class ConnectionManager(
}
case scala.util.Success(ackMessage) =>
if (ackMessage.hasError) {
+ val errorMsgByteBuf = ackMessage.asInstanceOf[BufferMessage].buffers.head
+ val errorMsgBytes = new Array[Byte](errorMsgByteBuf.limit())
+ errorMsgByteBuf.get(errorMsgBytes)
+ val errorMsg = new String(errorMsgBytes, "utf-8")
val e = new IOException(
- "sendMessageReliably failed with ACK that signalled a remote error")
+ s"sendMessageReliably failed with ACK that signalled a remote error: $errorMsg")
if (!promise.tryFailure(e)) {
logWarning("Ignore error because promise is completed", e)
}
diff --git a/core/src/main/scala/org/apache/spark/network/nio/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala
index 0b874c2891255..3ad04591da658 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/Message.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/Message.scala
@@ -22,6 +22,7 @@ import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.util.Utils
private[nio] abstract class Message(val typ: Long, val id: Int) {
var senderAddress: InetSocketAddress = null
@@ -84,6 +85,19 @@ private[nio] object Message {
createBufferMessage(new Array[ByteBuffer](0), ackId)
}
+ /**
+ * Create a "negative acknowledgment" to notify a sender that an error occurred
+ * while processing its message. The exception's stacktrace will be formatted
+ * as a string, serialized into a byte array, and sent as the message payload.
+ */
+ def createErrorMessage(exception: Exception, ackId: Int): BufferMessage = {
+ val exceptionString = Utils.exceptionString(exception)
+ val serializedExceptionString = ByteBuffer.wrap(exceptionString.getBytes("utf-8"))
+ val errorMessage = createBufferMessage(serializedExceptionString, ackId)
+ errorMessage.hasError = true
+ errorMessage
+ }
+
def create(header: MessageChunkHeader): Message = {
val newMessage: Message = header.typ match {
case BUFFER_MESSAGE => new BufferMessage(header.id,
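Taken together with the ConnectionManager change above, the remote error now travels as a UTF-8 payload. A tiny sketch of the encode/decode round trip (plain ByteBuffer handling, independent of the Message classes):

```scala
import java.nio.ByteBuffer

object ErrorPayloadRoundTrip {
  def main(args: Array[String]): Unit = {
    // Sender side: format the exception and wrap it as the message payload.
    val exceptionText = "java.io.IOException: Custom exception text"
    val payload = ByteBuffer.wrap(exceptionText.getBytes("utf-8"))

    // Receiver side: copy the payload out of the buffer and rebuild the string,
    // which is then appended to the IOException raised by sendMessageReliably.
    val bytes = new Array[Byte](payload.limit())
    payload.get(bytes)
    println(new String(bytes, "utf-8"))  // prints the original exception text
  }
}
```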
diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
index b389b9a2022c6..5add4fc433fb3 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
@@ -151,17 +151,14 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
} catch {
case e: Exception => {
logError("Exception handling buffer message", e)
- val errorMessage = Message.createBufferMessage(msg.id)
- errorMessage.hasError = true
- Some(errorMessage)
+ Some(Message.createErrorMessage(e, msg.id))
}
}
case otherMessage: Any =>
- logError("Unknown type message received: " + otherMessage)
- val errorMessage = Message.createBufferMessage(msg.id)
- errorMessage.hasError = true
- Some(errorMessage)
+ val errorMsg = s"Received unknown message type: ${otherMessage.getClass.getName}"
+ logError(errorMsg)
+ Some(Message.createErrorMessage(new UnsupportedOperationException(errorMsg), msg.id))
}
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
index 2987dc04494a5..f0e43fbf70976 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
@@ -71,19 +71,19 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: JobPr
{k}
{executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")}
- {UIUtils.formatDuration(v.taskTime)}
+ {UIUtils.formatDuration(v.taskTime)}
{v.failedTasks + v.succeededTasks}
{v.failedTasks}
{v.succeededTasks}
-
+
{Utils.bytesToString(v.inputBytes)}
-
+
{Utils.bytesToString(v.shuffleRead)}
-
+
{Utils.bytesToString(v.shuffleWrite)}
-
+
{Utils.bytesToString(v.memoryBytesSpilled)}
-
+
{Utils.bytesToString(v.diskBytesSpilled)}
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index 2e67310594784..4ee7f08ab47a2 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -176,9 +176,9 @@ private[ui] class StageTableBase(
{makeProgressBar(stageData.numActiveTasks, stageData.completedIndices.size,
stageData.numFailedTasks, s.numTasks)}
- {inputReadWithUnit}
- {shuffleReadWithUnit}
- {shuffleWriteWithUnit}
+ {inputReadWithUnit}
+ {shuffleReadWithUnit}
+ {shuffleWriteWithUnit}
}
/** Render an HTML row that represents a stage */
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala
index 716591c9ed449..83489ca0679ee 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala
@@ -58,9 +58,9 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") {
{rdd.numCachedPartitions}
{"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)}
- {Utils.bytesToString(rdd.memSize)}
- {Utils.bytesToString(rdd.tachyonSize)}
- {Utils.bytesToString(rdd.diskSize)}
+ {Utils.bytesToString(rdd.memSize)}
+ {Utils.bytesToString(rdd.tachyonSize)}
+ {Utils.bytesToString(rdd.diskSize)}
// scalastyle:on
}
diff --git a/core/src/main/scala/org/apache/spark/util/FileLogger.scala b/core/src/main/scala/org/apache/spark/util/FileLogger.scala
index 6d1fc05a15d2c..fdc73f08261a6 100644
--- a/core/src/main/scala/org/apache/spark/util/FileLogger.scala
+++ b/core/src/main/scala/org/apache/spark/util/FileLogger.scala
@@ -51,12 +51,27 @@ private[spark] class FileLogger(
def this(
logDir: String,
sparkConf: SparkConf,
- compress: Boolean = false,
- overwrite: Boolean = true) = {
+ compress: Boolean,
+ overwrite: Boolean) = {
this(logDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf), compress = compress,
overwrite = overwrite)
}
+ def this(
+ logDir: String,
+ sparkConf: SparkConf,
+ compress: Boolean) = {
+ this(logDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf), compress = compress,
+ overwrite = true)
+ }
+
+ def this(
+ logDir: String,
+ sparkConf: SparkConf) = {
+ this(logDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf), compress = false,
+ overwrite = true)
+ }
+
private val dateFormat = new ThreadLocal[SimpleDateFormat]() {
override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
}
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 3d307b3c16d3e..07477dd460a4b 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -168,6 +168,20 @@ private[spark] object Utils extends Logging {
private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]()
private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]()
+ // Add a shutdown hook to delete the temp dirs when the JVM exits
+ Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dirs") {
+ override def run(): Unit = Utils.logUncaughtExceptions {
+ logDebug("Shutdown hook called")
+ shutdownDeletePaths.foreach { dirPath =>
+ try {
+ Utils.deleteRecursively(new File(dirPath))
+ } catch {
+ case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e)
+ }
+ }
+ }
+ })
+
// Register the path to be deleted via shutdown hook
def registerShutdownDeleteDir(file: File) {
val absolutePath = file.getAbsolutePath()
@@ -252,14 +266,6 @@ private[spark] object Utils extends Logging {
}
registerShutdownDeleteDir(dir)
-
- // Add a shutdown hook to delete the temp dir when the JVM exits
- Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) {
- override def run() {
- // Attempt to delete if some patch which is parent of this is not already registered.
- if (! hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir)
- }
- })
dir
}
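The per-directory shutdown hooks are replaced by a single JVM-wide hook that walks the registered paths. A condensed sketch of that pattern (names here are illustrative, not the Utils internals):

```scala
import java.io.File
import scala.collection.mutable

object TempDirCleanup {
  private val registeredPaths = mutable.HashSet[String]()

  // One hook for the whole JVM instead of one thread per temp directory.
  Runtime.getRuntime.addShutdownHook(new Thread("delete registered temp dirs") {
    override def run(): Unit = registeredPaths.synchronized {
      registeredPaths.foreach { path =>
        try deleteRecursively(new File(path))
        catch { case e: Exception => System.err.println(s"Failed to delete $path: $e") }
      }
    }
  })

  def register(dir: File): Unit = registeredPaths.synchronized {
    registeredPaths += dir.getAbsolutePath
  }

  private def deleteRecursively(f: File): Unit = {
    if (f.isDirectory) Option(f.listFiles()).getOrElse(Array.empty).foreach(deleteRecursively)
    if (!f.delete() && f.exists()) throw new java.io.IOException(s"Failed to delete $f")
  }
}
```

Note that `deleteRecursively` in the next hunk also unregisters the path, so the hook does not revisit directories that were already cleaned up.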
@@ -666,15 +672,30 @@ private[spark] object Utils extends Logging {
*/
def deleteRecursively(file: File) {
if (file != null) {
- if (file.isDirectory() && !isSymlink(file)) {
- for (child <- listFilesSafely(file)) {
- deleteRecursively(child)
+ try {
+ if (file.isDirectory && !isSymlink(file)) {
+ var savedIOException: IOException = null
+ for (child <- listFilesSafely(file)) {
+ try {
+ deleteRecursively(child)
+ } catch {
+ // In case of multiple exceptions, only the last one will be thrown
+ case ioe: IOException => savedIOException = ioe
+ }
+ }
+ if (savedIOException != null) {
+ throw savedIOException
+ }
+ shutdownDeletePaths.synchronized {
+ shutdownDeletePaths.remove(file.getAbsolutePath)
+ }
}
- }
- if (!file.delete()) {
- // Delete can also fail if the file simply did not exist
- if (file.exists()) {
- throw new IOException("Failed to delete: " + file.getAbsolutePath)
+ } finally {
+ if (!file.delete()) {
+ // Delete can also fail if the file simply did not exist
+ if (file.exists()) {
+ throw new IOException("Failed to delete: " + file.getAbsolutePath)
+ }
}
}
}
@@ -713,7 +734,7 @@ private[spark] object Utils extends Logging {
*/
def doesDirectoryContainAnyNewFiles(dir: File, cutoff: Long): Boolean = {
if (!dir.isDirectory) {
- throw new IllegalArgumentException("$dir is not a directory!")
+ throw new IllegalArgumentException(s"$dir is not a directory!")
}
val filesAndDirs = dir.listFiles()
val cutoffTimeInMillis = System.currentTimeMillis - (cutoff * 1000)
diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
index 7e18f45de7b5b..a8867020e457d 100644
--- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark
import java.io._
import java.util.jar.{JarEntry, JarOutputStream}
-import com.google.common.io.Files
import org.scalatest.FunSuite
import org.apache.spark.SparkContext._
@@ -41,8 +40,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext {
override def beforeAll() {
super.beforeAll()
- tmpDir = Files.createTempDir()
- tmpDir.deleteOnExit()
+ tmpDir = Utils.createTempDir()
val testTempDir = new File(tmpDir, "test")
testTempDir.mkdir()
diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala
index 4a53d25012ad9..a2b74c4419d46 100644
--- a/core/src/test/scala/org/apache/spark/FileSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileSuite.scala
@@ -21,7 +21,6 @@ import java.io.{File, FileWriter}
import scala.io.Source
-import com.google.common.io.Files
import org.apache.hadoop.io._
import org.apache.hadoop.io.compress.DefaultCodec
import org.apache.hadoop.mapred.{JobConf, FileAlreadyExistsException, FileSplit, TextInputFormat, TextOutputFormat}
@@ -39,8 +38,7 @@ class FileSuite extends FunSuite with LocalSparkContext {
override def beforeEach() {
super.beforeEach()
- tempDir = Files.createTempDir()
- tempDir.deleteOnExit()
+ tempDir = Utils.createTempDir()
}
override def afterEach() {
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
new file mode 100644
index 0000000000000..31edad1c56c73
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.FunSuite
+
+import org.apache.hadoop.io.BytesWritable
+
+class SparkContextSuite extends FunSuite {
+ // Regression test for SPARK-3121
+ test("BytesWritable implicit conversion is correct") {
+ val bytesWritable = new BytesWritable()
+ val inputArray = (1 to 10).map(_.toByte).toArray
+ bytesWritable.set(inputArray, 0, 10)
+ bytesWritable.set(inputArray, 0, 5)
+
+ val converter = SparkContext.bytesWritableConverter()
+ val byteArray = converter.convert(bytesWritable)
+ assert(byteArray.length === 5)
+
+ bytesWritable.set(inputArray, 0, 0)
+ val byteArray2 = converter.convert(bytesWritable)
+ assert(byteArray2.length === 0)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 4cba90e8f2afe..1cdf50d5c08c7 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -26,7 +26,6 @@ import org.apache.spark.deploy.SparkSubmit._
import org.apache.spark.util.Utils
import org.scalatest.FunSuite
import org.scalatest.Matchers
-import com.google.common.io.Files
class SparkSubmitSuite extends FunSuite with Matchers {
def beforeAll() {
@@ -332,7 +331,7 @@ class SparkSubmitSuite extends FunSuite with Matchers {
}
def forConfDir(defaults: Map[String, String]) (f: String => Unit) = {
- val tmpDir = Files.createTempDir()
+ val tmpDir = Utils.createTempDir()
val defaultsConf = new File(tmpDir.getAbsolutePath, "spark-defaults.conf")
val writer = new OutputStreamWriter(new FileOutputStream(defaultsConf))
diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala
index d5ebfb3f3fae1..12d1c7b2faba6 100644
--- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala
@@ -23,8 +23,6 @@ import java.io.FileOutputStream
import scala.collection.immutable.IndexedSeq
-import com.google.common.io.Files
-
import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
@@ -66,9 +64,7 @@ class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll {
* 3) Does the contents be the same.
*/
test("Correctness of WholeTextFileRecordReader.") {
-
- val dir = Files.createTempDir()
- dir.deleteOnExit()
+ val dir = Utils.createTempDir()
println(s"Local disk address is ${dir.toString}.")
WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) =>
diff --git a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
index 9f49587cdc670..b70734dfe37cf 100644
--- a/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala
@@ -27,6 +27,7 @@ import scala.language.postfixOps
import org.scalatest.FunSuite
import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.util.Utils
/**
* Test the ConnectionManager with various security settings.
@@ -236,7 +237,7 @@ class ConnectionManagerSuite extends FunSuite {
val manager = new ConnectionManager(0, conf, securityManager)
val managerServer = new ConnectionManager(0, conf, securityManager)
managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- throw new Exception
+ throw new Exception("Custom exception text")
})
val size = 10 * 1024 * 1024
@@ -246,9 +247,10 @@ class ConnectionManagerSuite extends FunSuite {
val future = manager.sendMessageReliably(managerServer.id, bufferMessage)
- intercept[IOException] {
+ val exception = intercept[IOException] {
Await.result(future, 1 second)
}
+ assert(Utils.exceptionString(exception).contains("Custom exception text"))
manager.stop()
managerServer.stop()
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 75b01191901b8..3620e251cc139 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -24,13 +24,14 @@ import org.apache.hadoop.util.Progressable
import scala.collection.mutable.{ArrayBuffer, HashSet}
import scala.util.Random
-import com.google.common.io.Files
import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, OutputCommitter => NewOutputCommitter,
OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter,
TaskAttemptContext => NewTaskAttempContext}
import org.apache.spark.{Partitioner, SharedSparkContext}
import org.apache.spark.SparkContext._
+import org.apache.spark.util.Utils
+
import org.scalatest.FunSuite
class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
@@ -381,14 +382,16 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
}
test("zero-partition RDD") {
- val emptyDir = Files.createTempDir()
- emptyDir.deleteOnExit()
- val file = sc.textFile(emptyDir.getAbsolutePath)
- assert(file.partitions.size == 0)
- assert(file.collect().toList === Nil)
- // Test that a shuffle on the file works, because this used to be a bug
- assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
- emptyDir.delete()
+ val emptyDir = Utils.createTempDir()
+ try {
+ val file = sc.textFile(emptyDir.getAbsolutePath)
+ assert(file.partitions.isEmpty)
+ assert(file.collect().toList === Nil)
+ // Test that a shuffle on the file works, because this used to be a bug
+ assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
+ } finally {
+ Utils.deleteRecursively(emptyDir)
+ }
}
test("keys and values") {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
index 3efa85431876b..abc300fcffaf9 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark.scheduler
import scala.collection.mutable
import scala.io.Source
-import com.google.common.io.Files
import org.apache.hadoop.fs.{FileStatus, Path}
import org.json4s.jackson.JsonMethods._
import org.scalatest.{BeforeAndAfter, FunSuite}
@@ -51,8 +50,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter {
private var logDirPath: Path = _
before {
- testDir = Files.createTempDir()
- testDir.deleteOnExit()
+ testDir = Utils.createTempDir()
logDirPath = Utils.getFilePath(testDir, "spark-events")
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
index 48114feee6233..e05f373392d4a 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.scheduler
import java.io.{File, PrintWriter}
-import com.google.common.io.Files
import org.json4s.jackson.JsonMethods._
import org.scalatest.{BeforeAndAfter, FunSuite}
@@ -39,8 +38,7 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter {
private var testDir: File = _
before {
- testDir = Files.createTempDir()
- testDir.deleteOnExit()
+ testDir = Utils.createTempDir()
}
after {
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
index e4522e00a622d..bc5c74c126b74 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -19,22 +19,13 @@ package org.apache.spark.storage
import java.io.{File, FileWriter}
-import org.apache.spark.network.nio.NioBlockTransferService
-import org.apache.spark.shuffle.hash.HashShuffleManager
-
-import scala.collection.mutable
import scala.language.reflectiveCalls
-import akka.actor.Props
-import com.google.common.io.Files
import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
import org.apache.spark.SparkConf
-import org.apache.spark.scheduler.LiveListenerBus
-import org.apache.spark.serializer.JavaSerializer
-import org.apache.spark.util.{AkkaUtils, Utils}
-import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.util.Utils
class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
private val testConf = new SparkConf(false)
@@ -48,10 +39,8 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
override def beforeAll() {
super.beforeAll()
- rootDir0 = Files.createTempDir()
- rootDir0.deleteOnExit()
- rootDir1 = Files.createTempDir()
- rootDir1.deleteOnExit()
+ rootDir0 = Utils.createTempDir()
+ rootDir1 = Utils.createTempDir()
rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
}
diff --git a/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala b/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala
index c3dd156b40514..72466a3aa1130 100644
--- a/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/FileLoggerSuite.scala
@@ -21,7 +21,6 @@ import java.io.{File, IOException}
import scala.io.Source
-import com.google.common.io.Files
import org.apache.hadoop.fs.Path
import org.scalatest.{BeforeAndAfter, FunSuite}
@@ -44,7 +43,7 @@ class FileLoggerSuite extends FunSuite with BeforeAndAfter {
private var logDirPathString: String = _
before {
- testDir = Files.createTempDir()
+ testDir = Utils.createTempDir()
logDirPath = Utils.getFilePath(testDir, "test-file-logger")
logDirPathString = logDirPath.toString
}
@@ -75,13 +74,13 @@ class FileLoggerSuite extends FunSuite with BeforeAndAfter {
test("Logging when directory already exists") {
// Create the logging directory multiple times
- new FileLogger(logDirPathString, new SparkConf, overwrite = true).start()
- new FileLogger(logDirPathString, new SparkConf, overwrite = true).start()
- new FileLogger(logDirPathString, new SparkConf, overwrite = true).start()
+ new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = true).start()
+ new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = true).start()
+ new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = true).start()
// If overwrite is not enabled, an exception should be thrown
intercept[IOException] {
- new FileLogger(logDirPathString, new SparkConf, overwrite = false).start()
+ new FileLogger(logDirPathString, new SparkConf, compress = false, overwrite = false).start()
}
}
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index e63d9d085e385..0344da60dae66 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -112,7 +112,7 @@ class UtilsSuite extends FunSuite {
}
test("reading offset bytes of a file") {
- val tmpDir2 = Files.createTempDir()
+ val tmpDir2 = Utils.createTempDir()
tmpDir2.deleteOnExit()
val f1Path = tmpDir2 + "/f1"
val f1 = new FileOutputStream(f1Path)
@@ -141,7 +141,7 @@ class UtilsSuite extends FunSuite {
}
test("reading offset bytes across multiple files") {
- val tmpDir = Files.createTempDir()
+ val tmpDir = Utils.createTempDir()
tmpDir.deleteOnExit()
val files = (1 to 3).map(i => new File(tmpDir, i.toString))
Files.write("0123456789", files(0), Charsets.UTF_8)
@@ -308,4 +308,28 @@ class UtilsSuite extends FunSuite {
}
}
+ test("deleteRecursively") {
+ val tempDir1 = Utils.createTempDir()
+ assert(tempDir1.exists())
+ Utils.deleteRecursively(tempDir1)
+ assert(!tempDir1.exists())
+
+ val tempDir2 = Utils.createTempDir()
+ val tempFile1 = new File(tempDir2, "foo.txt")
+ Files.touch(tempFile1)
+ assert(tempFile1.exists())
+ Utils.deleteRecursively(tempFile1)
+ assert(!tempFile1.exists())
+
+ val tempDir3 = new File(tempDir2, "subdir")
+ assert(tempDir3.mkdir())
+ val tempFile2 = new File(tempDir3, "bar.txt")
+ Files.touch(tempFile2)
+ assert(tempFile2.exists())
+ Utils.deleteRecursively(tempDir2)
+ assert(!tempDir2.exists())
+ assert(!tempDir3.exists())
+ assert(!tempFile2.exists())
+ }
+
}
diff --git a/docs/_config.yml b/docs/_config.yml
index 78c92281a49d5..f4bf242ac191b 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -1,5 +1,7 @@
highlighter: pygments
markdown: kramdown
+gems:
+ - jekyll-redirect-from
# For some reason kramdown seems to behave differently on different
# OS/packages wrt encoding. So we hard code this config.
diff --git a/docs/programming-guide.md b/docs/programming-guide.md
index 8e8cc1dd983f8..18420afb27e3c 100644
--- a/docs/programming-guide.md
+++ b/docs/programming-guide.md
@@ -211,17 +211,17 @@ For a complete list of options, run `pyspark --help`. Behind the scenes,
It is also possible to launch the PySpark shell in [IPython](http://ipython.org), the
enhanced Python interpreter. PySpark works with IPython 1.0.0 and later. To
-use IPython, set the `PYSPARK_PYTHON` variable to `ipython` when running `bin/pyspark`:
+use IPython, set the `PYSPARK_DRIVER_PYTHON` variable to `ipython` when running `bin/pyspark`:
{% highlight bash %}
-$ PYSPARK_PYTHON=ipython ./bin/pyspark
+$ PYSPARK_DRIVER_PYTHON=ipython ./bin/pyspark
{% endhighlight %}
-You can customize the `ipython` command by setting `PYSPARK_PYTHON_OPTS`. For example, to launch
+You can customize the `ipython` command by setting `PYSPARK_DRIVER_PYTHON_OPTS`. For example, to launch
the [IPython Notebook](http://ipython.org/notebook.html) with PyLab plot support:
{% highlight bash %}
-$ PYSPARK_PYTHON=ipython PYSPARK_PYTHON_OPTS="notebook --pylab inline" ./bin/pyspark
+$ PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS="notebook --pylab inline" ./bin/pyspark
{% endhighlight %}
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 5c21e912ea160..738309c668387 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -494,7 +494,7 @@ methods for creating DStreams from files and Akka actors as input sources.
For simple text files, there is an easier method `streamingContext.textFileStream(dataDirectory)`. And file streams do not require running a receiver, hence does not require allocating cores.
-- **Streams based on Custom Actors:** DStreams can be created with data streams received through Akka actors by using `streamingContext.actorStream(actorProps, actor-name)`. See the [Custom Receiver Guide](#implementing-and-using-a-custom-actor-based-receiver) for more details.
+- **Streams based on Custom Actors:** DStreams can be created with data streams received through Akka actors by using `streamingContext.actorStream(actorProps, actor-name)`. See the [Custom Receiver Guide](streaming-custom-receivers.html#implementing-and-using-a-custom-actor-based-receiver) for more details.
- **Queue of RDDs as a Stream:** For testing a Spark Streaming application with test data, one can also create a DStream based on a queue of RDDs, using `streamingContext.queueStream(queueOfRDDs)`. Each RDD pushed into the queue will be treated as a batch of data in the DStream, and processed like a stream.
diff --git a/examples/src/main/python/streaming/hdfs_wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py
new file mode 100644
index 0000000000000..40faff0ccc7db
--- /dev/null
+++ b/examples/src/main/python/streaming/hdfs_wordcount.py
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+ Counts words in new text files created in the given directory
+ Usage: hdfs_wordcount.py <directory>
+ <directory> is the directory that Spark Streaming will use to find and read new text files.
+
+ To run this on your local machine on directory `localdir`, run this example
+ $ bin/spark-submit examples/src/main/python/streaming/hdfs_wordcount.py localdir
+
+ Then create a text file in `localdir` and the words in the file will get counted.
+"""
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+
+if __name__ == "__main__":
+ if len(sys.argv) != 2:
+ print >> sys.stderr, "Usage: hdfs_wordcount.py "
+ exit(-1)
+
+ sc = SparkContext(appName="PythonStreamingHDFSWordCount")
+ ssc = StreamingContext(sc, 1)
+
+ lines = ssc.textFileStream(sys.argv[1])
+ counts = lines.flatMap(lambda line: line.split(" "))\
+ .map(lambda x: (x, 1))\
+ .reduceByKey(lambda a, b: a+b)
+ counts.pprint()
+
+ ssc.start()
+ ssc.awaitTermination()
diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py
new file mode 100644
index 0000000000000..cfa9c1ff5bfbc
--- /dev/null
+++ b/examples/src/main/python/streaming/network_wordcount.py
@@ -0,0 +1,48 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+ Counts words in UTF8 encoded, '\n' delimited text received from the network every second.
+ Usage: network_wordcount.py <hostname> <port>
+ <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data.
+
+ To run this on your local machine, you need to first run a Netcat server
+ `$ nc -lk 9999`
+ and then run the example
+ `$ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999`
+"""
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+
+if __name__ == "__main__":
+ if len(sys.argv) != 3:
+ print >> sys.stderr, "Usage: network_wordcount.py "
+ exit(-1)
+ sc = SparkContext(appName="PythonStreamingNetworkWordCount")
+ ssc = StreamingContext(sc, 1)
+
+ lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2]))
+ counts = lines.flatMap(lambda line: line.split(" "))\
+ .map(lambda word: (word, 1))\
+ .reduceByKey(lambda a, b: a+b)
+ counts.pprint()
+
+ ssc.start()
+ ssc.awaitTermination()
diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py
new file mode 100644
index 0000000000000..18a9a5a452ffb
--- /dev/null
+++ b/examples/src/main/python/streaming/stateful_network_wordcount.py
@@ -0,0 +1,57 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+ Counts words in UTF8 encoded, '\n' delimited text received from the
+ network every second.
+
+ Usage: stateful_network_wordcount.py <hostname> <port>
+ <hostname> and <port> describe the TCP server that Spark Streaming
+ would connect to receive data.
+
+ To run this on your local machine, you need to first run a Netcat server
+ `$ nc -lk 9999`
+ and then run the example
+ `$ bin/spark-submit examples/src/main/python/streaming/stateful_network_wordcount.py \
+ localhost 9999`
+"""
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+
+if __name__ == "__main__":
+ if len(sys.argv) != 3:
+ print >> sys.stderr, "Usage: stateful_network_wordcount.py "
+ exit(-1)
+ sc = SparkContext(appName="PythonStreamingStatefulNetworkWordCount")
+ ssc = StreamingContext(sc, 1)
+ ssc.checkpoint("checkpoint")
+
+ def updateFunc(new_values, last_sum):
+ return sum(new_values) + (last_sum or 0)
+
+ lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2]))
+ running_counts = lines.flatMap(lambda line: line.split(" "))\
+ .map(lambda word: (word, 1))\
+ .updateStateByKey(updateFunc)
+
+ running_counts.pprint()
+
+ ssc.start()
+ ssc.awaitTermination()
diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala
index c4317a6aec798..45527d9382fd0 100644
--- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala
@@ -46,17 +46,6 @@ object Analytics extends Logging {
}
val options = mutable.Map(optionsList: _*)
- def pickPartitioner(v: String): PartitionStrategy = {
- // TODO: Use reflection rather than listing all the partitioning strategies here.
- v match {
- case "RandomVertexCut" => RandomVertexCut
- case "EdgePartition1D" => EdgePartition1D
- case "EdgePartition2D" => EdgePartition2D
- case "CanonicalRandomVertexCut" => CanonicalRandomVertexCut
- case _ => throw new IllegalArgumentException("Invalid PartitionStrategy: " + v)
- }
- }
-
val conf = new SparkConf()
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
@@ -67,7 +56,7 @@ object Analytics extends Logging {
sys.exit(1)
}
val partitionStrategy: Option[PartitionStrategy] = options.remove("partStrategy")
- .map(pickPartitioner(_))
+ .map(PartitionStrategy.fromString(_))
val edgeStorageLevel = options.remove("edgeStorageLevel")
.map(StorageLevel.fromString(_)).getOrElse(StorageLevel.MEMORY_ONLY)
val vertexStorageLevel = options.remove("vertexStorageLevel")
@@ -107,7 +96,7 @@ object Analytics extends Logging {
if (!outFname.isEmpty) {
logWarning("Saving pageranks of pages to " + outFname)
- pr.map{case (id, r) => id + "\t" + r}.saveAsTextFile(outFname)
+ pr.map { case (id, r) => id + "\t" + r }.saveAsTextFile(outFname)
}
sc.stop()
@@ -129,7 +118,7 @@ object Analytics extends Logging {
val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_))
val cc = ConnectedComponents.run(graph)
- println("Components: " + cc.vertices.map{ case (vid,data) => data}.distinct())
+ println("Components: " + cc.vertices.map { case (vid, data) => data }.distinct())
sc.stop()
case "triangles" =>
@@ -147,7 +136,7 @@ object Analytics extends Logging {
minEdgePartitions = numEPart,
edgeStorageLevel = edgeStorageLevel,
vertexStorageLevel = vertexStorageLevel)
- // TriangleCount requires the graph to be partitioned
+ // TriangleCount requires the graph to be partitioned
.partitionBy(partitionStrategy.getOrElse(RandomVertexCut)).cache()
val triangles = TriangleCount.run(graph)
println("Triangles: " + triangles.vertices.map {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index fa7a26f17c3ca..ebbd8e0257209 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -176,6 +176,8 @@ private class RandomForest (
timer.stop("findBestSplits")
}
+ baggedInput.unpersist()
+
timer.stop("total")
logInfo("Internal timing for DecisionTree:")
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 8ef2bb1bf6a78..0dbe766b4d917 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -67,8 +67,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
|0
|0 2:4.0 4:5.0 6:6.0
""".stripMargin
- val tempDir = Files.createTempDir()
- tempDir.deleteOnExit()
+ val tempDir = Utils.createTempDir()
val file = new File(tempDir.getPath, "part-00000")
Files.write(lines, file, Charsets.US_ASCII)
val path = tempDir.toURI.toString
@@ -100,7 +99,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
LabeledPoint(1.1, Vectors.sparse(3, Seq((0, 1.23), (2, 4.56)))),
LabeledPoint(0.0, Vectors.dense(1.01, 2.02, 3.03))
), 2)
- val tempDir = Files.createTempDir()
+ val tempDir = Utils.createTempDir()
val outputDir = new File(tempDir, "output")
MLUtils.saveAsLibSVMFile(examples, outputDir.toURI.toString)
val lines = outputDir.listFiles()
@@ -166,7 +165,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
Vectors.sparse(2, Array(1), Array(-1.0)),
Vectors.dense(0.0, 1.0)
), 2)
- val tempDir = Files.createTempDir()
+ val tempDir = Utils.createTempDir()
val outputDir = new File(tempDir, "vectors")
val path = outputDir.toURI.toString
vectors.saveAsTextFile(path)
@@ -181,7 +180,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
LabeledPoint(0.0, Vectors.sparse(2, Array(1), Array(-1.0))),
LabeledPoint(1.0, Vectors.dense(0.0, 1.0))
), 2)
- val tempDir = Files.createTempDir()
+ val tempDir = Utils.createTempDir()
val outputDir = new File(tempDir, "points")
val path = outputDir.toURI.toString
points.saveAsTextFile(path)
diff --git a/pom.xml b/pom.xml
index 31f456a03bb3b..6318a1ba0d933 100644
--- a/pom.xml
+++ b/pom.xml
@@ -227,6 +227,18 @@
false
+
+    <repository>
+      <id>spark-staging</id>
+      <name>Spring Staging Repository</name>
+      <url>https://oss.sonatype.org/content/repositories/orgspark-project-1085</url>
+      <releases>
+        <enabled>true</enabled>
+      </releases>
+      <snapshots>
+        <enabled>false</enabled>
+      </snapshots>
+    </repository>
diff --git a/python/docs/conf.py b/python/docs/conf.py
index 8e6324f058251..e58d97ae6a746 100644
--- a/python/docs/conf.py
+++ b/python/docs/conf.py
@@ -131,7 +131,7 @@
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+#html_static_path = ['_static']
# Add any extra paths that contain custom files (such as robots.txt or
# .htaccess) here, relative to this directory. These files are copied
diff --git a/python/docs/epytext.py b/python/docs/epytext.py
index 61d731bff570d..19fefbfc057a4 100644
--- a/python/docs/epytext.py
+++ b/python/docs/epytext.py
@@ -5,7 +5,7 @@
(r"L{([\w.()]+)}", r":class:`\1`"),
(r"[LC]{(\w+\.\w+)\(\)}", r":func:`\1`"),
(r"C{([\w.()]+)}", r":class:`\1`"),
- (r"[IBCM]{(.+)}", r"`\1`"),
+ (r"[IBCM]{([^}]+)}", r"`\1`"),
('pyspark.rdd.RDD', 'RDD'),
)
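The epytext rule previously used a greedy `.+`, which lets a single match run across several `{...}` groups on one line; excluding `}` from the capture keeps each group separate. The same behaviour, sketched in Scala for illustration (the regex semantics are identical):

```scala
object RegexGreediness {
  def main(args: Array[String]): Unit = {
    val text = "B{first} and C{second}"

    val greedy = """[IBCM]\{(.+)\}""".r
    println(greedy.findAllMatchIn(text).map(_.group(1)).toList)
    // List(first} and C{second)  -- one match swallowing both groups

    val bounded = """[IBCM]\{([^}]+)\}""".r
    println(bounded.findAllMatchIn(text).map(_.group(1)).toList)
    // List(first, second)        -- one capture per group
  }
}
```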
diff --git a/python/docs/index.rst b/python/docs/index.rst
index d66e051b15371..703bef644de28 100644
--- a/python/docs/index.rst
+++ b/python/docs/index.rst
@@ -13,6 +13,7 @@ Contents:
pyspark
pyspark.sql
+ pyspark.streaming
pyspark.mllib
diff --git a/python/docs/pyspark.rst b/python/docs/pyspark.rst
index a68bd62433085..e81be3b6cb796 100644
--- a/python/docs/pyspark.rst
+++ b/python/docs/pyspark.rst
@@ -7,8 +7,9 @@ Subpackages
.. toctree::
:maxdepth: 1
- pyspark.mllib
pyspark.sql
+ pyspark.streaming
+ pyspark.mllib
Contents
--------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 6fb30d65c5edd..89d2e2e5b4a8e 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -29,7 +29,7 @@
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
- PairDeserializer, CompressedSerializer
+ PairDeserializer, CompressedSerializer, AutoBatchedSerializer
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD
from pyspark.traceback_utils import CallSite, first_spark_call
@@ -67,8 +67,8 @@ class SparkContext(object):
_default_batch_size_for_serialized_input = 10
def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
- environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None,
- gateway=None):
+ environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
+ gateway=None, jsc=None):
"""
Create a new SparkContext. At least the master and app name should be set,
either through the named parameters here or through C{conf}.
@@ -83,8 +83,9 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
:param environment: A dictionary of environment variables to set on
worker nodes.
:param batchSize: The number of Python objects represented as a single
- Java object. Set 1 to disable batching or -1 to use an
- unlimited batch size.
+ Java object. Set 1 to disable batching, 0 to automatically choose
+ the batch size based on object sizes, or -1 to use an unlimited
+ batch size
:param serializer: The serializer for RDDs.
:param conf: A L{SparkConf} object setting Spark properties.
:param gateway: Use an existing gateway and JVM, otherwise a new JVM
@@ -103,20 +104,22 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
SparkContext._ensure_initialized(self, gateway=gateway)
try:
self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
- conf)
+ conf, jsc)
except:
# If an error occurs, clean up in order to allow future SparkContext creation:
self.stop()
raise
def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
- conf):
+ conf, jsc):
self.environment = environment or {}
self._conf = conf or SparkConf(_jvm=self._jvm)
self._batchSize = batchSize # -1 represents an unlimited batch size
self._unbatched_serializer = serializer
if batchSize == 1:
self.serializer = self._unbatched_serializer
+ elif batchSize == 0:
+ self.serializer = AutoBatchedSerializer(self._unbatched_serializer)
else:
self.serializer = BatchedSerializer(self._unbatched_serializer,
batchSize)
@@ -151,7 +154,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
self.environment[varName] = v
# Create the Java SparkContext through Py4J
- self._jsc = self._initialize_context(self._conf._jconf)
+ self._jsc = jsc or self._initialize_context(self._conf._jconf)
# Create a single Accumulator in Java that we'll send all our updates through;
# they will be passed back to us through a TCP server
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index a44a27fd3b6a6..f4cbf31b94fe2 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -44,6 +44,7 @@ def transform(self, word):
"""
:param word: a word
:return: vector representation of word
+
Transforms a word to its vector representation
Note: local use only
@@ -57,6 +58,7 @@ def findSynonyms(self, x, num):
:param x: a word or a vector representation of word
:param num: number of synonyms to find
:return: array of (word, cosineSimilarity)
+
Find synonyms of a word
Note: local use only
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 5c20e100e144f..463faf7b6f520 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -25,7 +25,11 @@
from numpy import array, array_equal
if sys.version_info[:2] <= (2, 6):
- import unittest2 as unittest
+ try:
+ import unittest2 as unittest
+ except ImportError:
+ sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
+ sys.exit(1)
else:
import unittest
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 6797d50659a92..e13bab946c44a 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2009,7 +2009,7 @@ def countApproxDistinct(self, relativeSD=0.05):
of The Art Cardinality Estimation Algorithm", available
here .
- :param relativeSD Relative accuracy. Smaller values create
+ :param relativeSD: Relative accuracy. Smaller values create
counters that require more space.
It must be greater than 0.000017.
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 099fa54cf2bd7..08a0f0d8ffb3e 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -114,6 +114,9 @@ def __ne__(self, other):
def __repr__(self):
return "<%s object>" % self.__class__.__name__
+ def __hash__(self):
+ return hash(str(self))
+
class FramedSerializer(Serializer):
@@ -220,7 +223,7 @@ class AutoBatchedSerializer(BatchedSerializer):
Choose the size of batch automatically based on the size of object
"""
- def __init__(self, serializer, bestSize=1 << 20):
+ def __init__(self, serializer, bestSize=1 << 16):
BatchedSerializer.__init__(self, serializer, -1)
self.bestSize = bestSize
@@ -247,7 +250,7 @@ def __eq__(self, other):
other.serializer == self.serializer)
def __str__(self):
- return "BatchedSerializer<%s>" % str(self.serializer)
+ return "AutoBatchedSerializer<%s>" % str(self.serializer)
class CartesianDeserializer(FramedSerializer):
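A small illustration of the constructor change above: the target batch size now defaults to roughly 64 KB of serialized data per batch, and a larger target can still be passed explicitly (a sketch only, not tied to any particular workload):

    from pyspark.serializers import AutoBatchedSerializer, PickleSerializer

    default_ser = AutoBatchedSerializer(PickleSerializer())           # bestSize = 1 << 16 (64 KB)
    big_batches = AutoBatchedSerializer(PickleSerializer(), 1 << 20)  # previous 1 MB target

    print default_ser  # AutoBatchedSerializer<PickleSerializer>, per the corrected __str__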
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index d3d36eb995ab6..b31a82f9b19ac 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -19,14 +19,14 @@
public classes of Spark SQL:
- L{SQLContext}
- Main entry point for SQL functionality.
+ Main entry point for SQL functionality.
- L{SchemaRDD}
- A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In
- addition to normal RDD operations, SchemaRDDs also support SQL.
+ A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In
+ addition to normal RDD operations, SchemaRDDs also support SQL.
- L{Row}
- A Row of data returned by a Spark SQL query.
+ A Row of data returned by a Spark SQL query.
- L{HiveContext}
- Main entry point for accessing data stored in Apache Hive..
+      Main entry point for accessing data stored in Apache Hive.
"""
import itertools
diff --git a/python/pyspark/streaming/__init__.py b/python/pyspark/streaming/__init__.py
new file mode 100644
index 0000000000000..d2644a1d4ffab
--- /dev/null
+++ b/python/pyspark/streaming/__init__.py
@@ -0,0 +1,21 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.streaming.context import StreamingContext
+from pyspark.streaming.dstream import DStream
+
+__all__ = ['StreamingContext', 'DStream']
diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
new file mode 100644
index 0000000000000..dc9dc41121935
--- /dev/null
+++ b/python/pyspark/streaming/context.py
@@ -0,0 +1,325 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import sys
+
+from py4j.java_collections import ListConverter
+from py4j.java_gateway import java_import, JavaObject
+
+from pyspark import RDD, SparkConf
+from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer
+from pyspark.context import SparkContext
+from pyspark.storagelevel import StorageLevel
+from pyspark.streaming.dstream import DStream
+from pyspark.streaming.util import TransformFunction, TransformFunctionSerializer
+
+__all__ = ["StreamingContext"]
+
+
+def _daemonize_callback_server():
+ """
+ Hack Py4J to daemonize callback server
+
+    The callback server's thread has daemon=False, so it will block the driver
+    from exiting if it is not shut down. The following code replaces `start()`
+    of CallbackServer with a new version that sets daemon=True for this
+    thread.
+
+    Also, it updates the port number (0) with the real port
+ """
+ # TODO: create a patch for Py4J
+ import socket
+ import py4j.java_gateway
+ logger = py4j.java_gateway.logger
+ from py4j.java_gateway import Py4JNetworkError
+ from threading import Thread
+
+ def start(self):
+ """Starts the CallbackServer. This method should be called by the
+ client instead of run()."""
+ self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR,
+ 1)
+ try:
+ self.server_socket.bind((self.address, self.port))
+ if not self.port:
+ # update port with real port
+ self.port = self.server_socket.getsockname()[1]
+ except Exception as e:
+ msg = 'An error occurred while trying to start the callback server: %s' % e
+ logger.exception(msg)
+ raise Py4JNetworkError(msg)
+
+        # Maybe the thread needs to be cleaned up?
+ self.thread = Thread(target=self.run)
+ self.thread.daemon = True
+ self.thread.start()
+
+ py4j.java_gateway.CallbackServer.start = start
+
+
+class StreamingContext(object):
+ """
+ Main entry point for Spark Streaming functionality. A StreamingContext
+ represents the connection to a Spark cluster, and can be used to create
+    L{DStream}s from various input sources. It can be created from an existing L{SparkContext}.
+ After creating and transforming DStreams, the streaming computation can
+ be started and stopped using `context.start()` and `context.stop()`,
+ respectively. `context.awaitTransformation()` allows the current thread
+    respectively. `context.awaitTermination()` allows the current thread
+ """
+ _transformerSerializer = None
+
+ def __init__(self, sparkContext, batchDuration=None, jssc=None):
+ """
+ Create a new StreamingContext.
+
+ @param sparkContext: L{SparkContext} object.
+ @param batchDuration: the time interval (in seconds) at which streaming
+ data will be divided into batches
+ """
+
+ self._sc = sparkContext
+ self._jvm = self._sc._jvm
+ self._jssc = jssc or self._initialize_context(self._sc, batchDuration)
+
+ def _initialize_context(self, sc, duration):
+ self._ensure_initialized()
+ return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration))
+
+ def _jduration(self, seconds):
+ """
+ Create Duration object given number of seconds
+ """
+ return self._jvm.Duration(int(seconds * 1000))
+
+ @classmethod
+ def _ensure_initialized(cls):
+ SparkContext._ensure_initialized()
+ gw = SparkContext._gateway
+
+ java_import(gw.jvm, "org.apache.spark.streaming.*")
+ java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
+ java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")
+
+ # start callback server
+ # getattr will fallback to JVM, so we cannot test by hasattr()
+ if "_callback_server" not in gw.__dict__:
+ _daemonize_callback_server()
+ # use random port
+ gw._start_callback_server(0)
+ # gateway with real port
+ gw._python_proxy_port = gw._callback_server.port
+ # get the GatewayServer object in JVM by ID
+ jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
+ # update the port of CallbackClient with real port
+ gw.jvm.PythonDStream.updatePythonGatewayPort(jgws, gw._python_proxy_port)
+
+ # register serializer for TransformFunction
+ # it happens before creating SparkContext when loading from checkpointing
+ cls._transformerSerializer = TransformFunctionSerializer(
+ SparkContext._active_spark_context, CloudPickleSerializer(), gw)
+
+ @classmethod
+ def getOrCreate(cls, checkpointPath, setupFunc):
+ """
+ Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+        If checkpoint data exists in the provided `checkpointPath`, the StreamingContext will be
+        recreated from that data. If the data does not exist, the provided setupFunc
+        will be used to create a new StreamingContext.
+
+        @param checkpointPath: Checkpoint directory used in an earlier streaming program
+        @param setupFunc: Function to create a new StreamingContext and set up DStreams
+ """
+ # TODO: support checkpoint in HDFS
+ if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath):
+ ssc = setupFunc()
+ ssc.checkpoint(checkpointPath)
+ return ssc
+
+ cls._ensure_initialized()
+ gw = SparkContext._gateway
+
+ try:
+ jssc = gw.jvm.JavaStreamingContext(checkpointPath)
+ except Exception:
+ print >>sys.stderr, "failed to load StreamingContext from checkpoint"
+ raise
+
+ jsc = jssc.sparkContext()
+ conf = SparkConf(_jconf=jsc.getConf())
+ sc = SparkContext(conf=conf, gateway=gw, jsc=jsc)
+ # update ctx in serializer
+ SparkContext._active_spark_context = sc
+ cls._transformerSerializer.ctx = sc
+ return StreamingContext(sc, None, jssc)
+
+ @property
+ def sparkContext(self):
+ """
+ Return SparkContext which is associated with this StreamingContext.
+ """
+ return self._sc
+
+ def start(self):
+ """
+ Start the execution of the streams.
+ """
+ self._jssc.start()
+
+ def awaitTermination(self, timeout=None):
+ """
+ Wait for the execution to stop.
+ @param timeout: time to wait in seconds
+ """
+ if timeout is None:
+ self._jssc.awaitTermination()
+ else:
+ self._jssc.awaitTermination(int(timeout * 1000))
+
+ def stop(self, stopSparkContext=True, stopGraceFully=False):
+ """
+ Stop the execution of the streams, with option of ensuring all
+ received data has been processed.
+
+ @param stopSparkContext: Stop the associated SparkContext or not
+        @param stopGraceFully: Stop gracefully by waiting for the processing
+ of all received data to be completed
+ """
+ self._jssc.stop(stopSparkContext, stopGraceFully)
+ if stopSparkContext:
+ self._sc.stop()
+
+ def remember(self, duration):
+ """
+        Set each DStream in this context to remember the RDDs it generated
+        in the last given duration. DStreams remember RDDs only for a
+        limited duration of time and release them for garbage collection.
+        This method allows the developer to specify how long to remember
+        the RDDs (if the developer wishes to query old data outside the
+        DStream computation).
+
+ @param duration: Minimum duration (in seconds) that each DStream
+ should remember its RDDs
+ """
+ self._jssc.remember(self._jduration(duration))
+
+ def checkpoint(self, directory):
+ """
+ Sets the context to periodically checkpoint the DStream operations for master
+ fault-tolerance. The graph will be checkpointed every batch interval.
+
+ @param directory: HDFS-compatible directory where the checkpoint data
+ will be reliably stored
+ """
+ self._jssc.checkpoint(directory)
+
+ def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2):
+ """
+        Create an input stream from a TCP source hostname:port. Data is received
+        using a TCP socket and the received bytes are interpreted as UTF8-encoded,
+        ``\\n``-delimited lines.
+
+ @param hostname: Hostname to connect to for receiving data
+ @param port: Port to connect to for receiving data
+ @param storageLevel: Storage level to use for storing the received objects
+ """
+ jlevel = self._sc._getJavaStorageLevel(storageLevel)
+ return DStream(self._jssc.socketTextStream(hostname, port, jlevel), self,
+ UTF8Deserializer())
+
+ def textFileStream(self, directory):
+ """
+ Create an input stream that monitors a Hadoop-compatible file system
+        for new files and reads them as text files. Files must be written to the
+ monitored directory by "moving" them from another location within the same
+ file system. File names starting with . are ignored.
+ """
+ return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer())
+
+ def _check_serializers(self, rdds):
+ # make sure they have same serializer
+ if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1:
+ for i in range(len(rdds)):
+ # reset them to sc.serializer
+ rdds[i] = rdds[i]._reserialize()
+
+ def queueStream(self, rdds, oneAtATime=True, default=None):
+ """
+        Create an input stream from a queue of RDDs or lists. In each batch,
+ it will process either one or all of the RDDs returned by the queue.
+
+ NOTE: changes to the queue after the stream is created will not be recognized.
+
+ @param rdds: Queue of RDDs
+        @param oneAtATime: pick one RDD each time or pick all of them once.
+        @param default: The default RDD if no more RDDs remain in the queue
+ """
+ if default and not isinstance(default, RDD):
+ default = self._sc.parallelize(default)
+
+ if not rdds and default:
+ rdds = [rdds]
+
+ if rdds and not isinstance(rdds[0], RDD):
+ rdds = [self._sc.parallelize(input) for input in rdds]
+ self._check_serializers(rdds)
+
+ jrdds = ListConverter().convert([r._jrdd for r in rdds],
+ SparkContext._gateway._gateway_client)
+ queue = self._jvm.PythonDStream.toRDDQueue(jrdds)
+ if default:
+ default = default._reserialize(rdds[0]._jrdd_deserializer)
+ jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd)
+ else:
+ jdstream = self._jssc.queueStream(queue, oneAtATime)
+ return DStream(jdstream, self, rdds[0]._jrdd_deserializer)
+
+ def transform(self, dstreams, transformFunc):
+ """
+ Create a new DStream in which each RDD is generated by applying
+ a function on RDDs of the DStreams. The order of the JavaRDDs in
+ the transform function parameter will be the same as the order
+ of corresponding DStreams in the list.
+ """
+ jdstreams = ListConverter().convert([d._jdstream for d in dstreams],
+ SparkContext._gateway._gateway_client)
+ # change the final serializer to sc.serializer
+ func = TransformFunction(self._sc,
+ lambda t, *rdds: transformFunc(rdds).map(lambda x: x),
+ *[d._jrdd_deserializer for d in dstreams])
+ jfunc = self._jvm.TransformFunction(func)
+ jdstream = self._jssc.transform(jdstreams, jfunc)
+ return DStream(jdstream, self, self._sc.serializer)
+
+ def union(self, *dstreams):
+ """
+ Create a unified DStream from multiple DStreams of the same
+ type and same slide duration.
+ """
+ if not dstreams:
+ raise ValueError("should have at least one DStream to union")
+ if len(dstreams) == 1:
+ return dstreams[0]
+ if len(set(s._jrdd_deserializer for s in dstreams)) > 1:
+            raise ValueError("All DStreams should have the same serializer")
+ if len(set(s._slideDuration for s in dstreams)) > 1:
+            raise ValueError("All DStreams should have the same slide duration")
+ first = dstreams[0]
+ jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]],
+ SparkContext._gateway._gateway_client)
+ return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer)
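Putting the context API above together, a minimal word count sketch; it assumes some external process (e.g. `nc -lk 9999`) is writing lines to localhost:9999, and the host, port and app name are placeholders:

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    sc = SparkContext("local[2]", "NetworkWordCount")
    ssc = StreamingContext(sc, 1)                    # 1-second batches

    lines = ssc.socketTextStream("localhost", 9999)
    counts = lines.flatMap(lambda line: line.split(" ")) \
                  .map(lambda word: (word, 1)) \
                  .reduceByKey(lambda a, b: a + b)
    counts.pprint()

    ssc.start()
    ssc.awaitTermination()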
diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
new file mode 100644
index 0000000000000..5ae5cf07f0137
--- /dev/null
+++ b/python/pyspark/streaming/dstream.py
@@ -0,0 +1,621 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from itertools import chain, ifilter, imap
+import operator
+import time
+from datetime import datetime
+
+from py4j.protocol import Py4JJavaError
+
+from pyspark import RDD
+from pyspark.storagelevel import StorageLevel
+from pyspark.streaming.util import rddToFileName, TransformFunction
+from pyspark.rdd import portable_hash
+from pyspark.resultiterable import ResultIterable
+
+__all__ = ["DStream"]
+
+
+class DStream(object):
+ """
+ A Discretized Stream (DStream), the basic abstraction in Spark Streaming,
+ is a continuous sequence of RDDs (of the same type) representing a
+ continuous stream of data (see L{RDD} in the Spark core documentation
+ for more details on RDDs).
+
+    DStreams can either be created from live data (such as data from TCP
+    sockets, Kafka, Flume, etc.) using a L{StreamingContext}, or they can be
+    generated by transforming existing DStreams using operations such as
+    `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming
+    program is running, each DStream periodically generates an RDD, either
+    from live data or by transforming the RDD generated by a parent DStream.
+
+    Internally, a DStream is characterized by a few basic properties:
+ - A list of other DStreams that the DStream depends on
+ - A time interval at which the DStream generates an RDD
+ - A function that is used to generate an RDD after each time interval
+ """
+ def __init__(self, jdstream, ssc, jrdd_deserializer):
+ self._jdstream = jdstream
+ self._ssc = ssc
+ self._sc = ssc._sc
+ self._jrdd_deserializer = jrdd_deserializer
+ self.is_cached = False
+ self.is_checkpointed = False
+
+ def context(self):
+ """
+ Return the StreamingContext associated with this DStream
+ """
+ return self._ssc
+
+ def count(self):
+ """
+ Return a new DStream in which each RDD has a single element
+ generated by counting each RDD of this DStream.
+ """
+ return self.mapPartitions(lambda i: [sum(1 for _ in i)]).reduce(operator.add)
+
+ def filter(self, f):
+ """
+        Return a new DStream containing only the elements that satisfy the predicate.
+ """
+ def func(iterator):
+ return ifilter(f, iterator)
+ return self.mapPartitions(func, True)
+
+ def flatMap(self, f, preservesPartitioning=False):
+ """
+        Return a new DStream by applying a function to all elements of
+        this DStream, and then flattening the results.
+ """
+ def func(s, iterator):
+ return chain.from_iterable(imap(f, iterator))
+ return self.mapPartitionsWithIndex(func, preservesPartitioning)
+
+ def map(self, f, preservesPartitioning=False):
+ """
+ Return a new DStream by applying a function to each element of DStream.
+ """
+ def func(iterator):
+ return imap(f, iterator)
+ return self.mapPartitions(func, preservesPartitioning)
+
+ def mapPartitions(self, f, preservesPartitioning=False):
+ """
+ Return a new DStream in which each RDD is generated by applying
+        mapPartitions() to each RDD of this DStream.
+ """
+ def func(s, iterator):
+ return f(iterator)
+ return self.mapPartitionsWithIndex(func, preservesPartitioning)
+
+ def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
+ """
+ Return a new DStream in which each RDD is generated by applying
+        mapPartitionsWithIndex() to each RDD of this DStream.
+ """
+ return self.transform(lambda rdd: rdd.mapPartitionsWithIndex(f, preservesPartitioning))
+
+ def reduce(self, func):
+ """
+ Return a new DStream in which each RDD has a single element
+ generated by reducing each RDD of this DStream.
+ """
+ return self.map(lambda x: (None, x)).reduceByKey(func, 1).map(lambda x: x[1])
+
+ def reduceByKey(self, func, numPartitions=None):
+ """
+ Return a new DStream by applying reduceByKey to each RDD.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+ return self.combineByKey(lambda x: x, func, func, numPartitions)
+
+ def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
+ numPartitions=None):
+ """
+ Return a new DStream by applying combineByKey to each RDD.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+
+ def func(rdd):
+ return rdd.combineByKey(createCombiner, mergeValue, mergeCombiners, numPartitions)
+ return self.transform(func)
+
+ def partitionBy(self, numPartitions, partitionFunc=portable_hash):
+ """
+        Return a copy of the DStream in which each RDD is partitioned
+        using the specified partitioner.
+ """
+ return self.transform(lambda rdd: rdd.partitionBy(numPartitions, partitionFunc))
+
+ def foreachRDD(self, func):
+ """
+ Apply a function to each RDD in this DStream.
+ """
+ if func.func_code.co_argcount == 1:
+ old_func = func
+ func = lambda t, rdd: old_func(rdd)
+ jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer)
+ api = self._ssc._jvm.PythonDStream
+ api.callForeachRDD(self._jdstream, jfunc)
+
+ def pprint(self):
+ """
+ Print the first ten elements of each RDD generated in this DStream.
+ """
+ def takeAndPrint(time, rdd):
+ taken = rdd.take(11)
+ print "-------------------------------------------"
+ print "Time: %s" % time
+ print "-------------------------------------------"
+ for record in taken[:10]:
+ print record
+ if len(taken) > 10:
+ print "..."
+ print
+
+ self.foreachRDD(takeAndPrint)
+
+ def mapValues(self, f):
+ """
+ Return a new DStream by applying a map function to the value of
+        each key-value pair in this DStream without changing the key.
+ """
+ map_values_fn = lambda (k, v): (k, f(v))
+ return self.map(map_values_fn, preservesPartitioning=True)
+
+ def flatMapValues(self, f):
+ """
+        Return a new DStream by applying a flatMap function to the value
+        of each key-value pair in this DStream without changing the key.
+ """
+ flat_map_fn = lambda (k, v): ((k, x) for x in f(v))
+ return self.flatMap(flat_map_fn, preservesPartitioning=True)
+
+ def glom(self):
+ """
+        Return a new DStream in which each RDD is generated by applying glom()
+        to each RDD of this DStream.
+ """
+ def func(iterator):
+ yield list(iterator)
+ return self.mapPartitions(func)
+
+ def cache(self):
+ """
+ Persist the RDDs of this DStream with the default storage level
+ (C{MEMORY_ONLY_SER}).
+ """
+ self.is_cached = True
+ self.persist(StorageLevel.MEMORY_ONLY_SER)
+ return self
+
+ def persist(self, storageLevel):
+ """
+ Persist the RDDs of this DStream with the given storage level
+ """
+ self.is_cached = True
+ javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
+ self._jdstream.persist(javaStorageLevel)
+ return self
+
+ def checkpoint(self, interval):
+ """
+ Enable periodic checkpointing of RDDs of this DStream
+
+        @param interval: time in seconds; after each such period, the generated
+                         RDDs will be checkpointed
+ """
+ self.is_checkpointed = True
+ self._jdstream.checkpoint(self._ssc._jduration(interval))
+ return self
+
+ def groupByKey(self, numPartitions=None):
+ """
+ Return a new DStream by applying groupByKey on each RDD.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+ return self.transform(lambda rdd: rdd.groupByKey(numPartitions))
+
+ def countByValue(self):
+ """
+ Return a new DStream in which each RDD contains the counts of each
+ distinct value in each RDD of this DStream.
+ """
+ return self.map(lambda x: (x, None)).reduceByKey(lambda x, y: None).count()
+
+ def saveAsTextFiles(self, prefix, suffix=None):
+ """
+        Save each RDD in this DStream as a text file, using the string
+        representation of its elements.
+ """
+ def saveAsTextFile(t, rdd):
+ path = rddToFileName(prefix, suffix, t)
+ try:
+ rdd.saveAsTextFile(path)
+ except Py4JJavaError as e:
+ # after recovered from checkpointing, the foreachRDD may
+ # be called twice
+ if 'FileAlreadyExistsException' not in str(e):
+ raise
+ return self.foreachRDD(saveAsTextFile)
+
+    # TODO: uncomment this once we have ssc.pickleFileStream()
+ # def saveAsPickleFiles(self, prefix, suffix=None):
+ # """
+    #     Save each RDD in this DStream as a binary file; the elements are
+ # serialized by pickle.
+ # """
+ # def saveAsPickleFile(t, rdd):
+ # path = rddToFileName(prefix, suffix, t)
+ # try:
+ # rdd.saveAsPickleFile(path)
+ # except Py4JJavaError as e:
+ # # after recovered from checkpointing, the foreachRDD may
+ # # be called twice
+ # if 'FileAlreadyExistsException' not in str(e):
+ # raise
+ # return self.foreachRDD(saveAsPickleFile)
+
+ def transform(self, func):
+ """
+ Return a new DStream in which each RDD is generated by applying a function
+ on each RDD of this DStream.
+
+        `func` can take one argument (`rdd`) or two arguments (`time`, `rdd`).
+ """
+ if func.func_code.co_argcount == 1:
+ oldfunc = func
+ func = lambda t, rdd: oldfunc(rdd)
+ assert func.func_code.co_argcount == 2, "func should take one or two arguments"
+ return TransformedDStream(self, func)
+
+ def transformWith(self, func, other, keepSerializer=False):
+ """
+ Return a new DStream in which each RDD is generated by applying a function
+ on each RDD of this DStream and 'other' DStream.
+
+        `func` can take two arguments (`rdd_a`, `rdd_b`) or three arguments
+        (`time`, `rdd_a`, `rdd_b`).
+ """
+ if func.func_code.co_argcount == 2:
+ oldfunc = func
+ func = lambda t, a, b: oldfunc(a, b)
+ assert func.func_code.co_argcount == 3, "func should take two or three arguments"
+ jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer, other._jrdd_deserializer)
+ dstream = self._sc._jvm.PythonTransformed2DStream(self._jdstream.dstream(),
+ other._jdstream.dstream(), jfunc)
+ jrdd_serializer = self._jrdd_deserializer if keepSerializer else self._sc.serializer
+ return DStream(dstream.asJavaDStream(), self._ssc, jrdd_serializer)
+
+ def repartition(self, numPartitions):
+ """
+ Return a new DStream with an increased or decreased level of parallelism.
+ """
+ return self.transform(lambda rdd: rdd.repartition(numPartitions))
+
+ @property
+ def _slideDuration(self):
+ """
+ Return the slideDuration in seconds of this DStream
+ """
+ return self._jdstream.dstream().slideDuration().milliseconds() / 1000.0
+
+ def union(self, other):
+ """
+ Return a new DStream by unifying data of another DStream with this DStream.
+
+ @param other: Another DStream having the same interval (i.e., slideDuration)
+ as this DStream.
+ """
+ if self._slideDuration != other._slideDuration:
+            raise ValueError("the two DStreams should have the same slide duration")
+ return self.transformWith(lambda a, b: a.union(b), other, True)
+
+ def cogroup(self, other, numPartitions=None):
+ """
+ Return a new DStream by applying 'cogroup' between RDDs of this
+ DStream and `other` DStream.
+
+ Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+ return self.transformWith(lambda a, b: a.cogroup(b, numPartitions), other)
+
+ def join(self, other, numPartitions=None):
+ """
+ Return a new DStream by applying 'join' between RDDs of this DStream and
+ `other` DStream.
+
+ Hash partitioning is used to generate the RDDs with `numPartitions`
+ partitions.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+ return self.transformWith(lambda a, b: a.join(b, numPartitions), other)
+
+ def leftOuterJoin(self, other, numPartitions=None):
+ """
+ Return a new DStream by applying 'left outer join' between RDDs of this DStream and
+ `other` DStream.
+
+ Hash partitioning is used to generate the RDDs with `numPartitions`
+ partitions.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+ return self.transformWith(lambda a, b: a.leftOuterJoin(b, numPartitions), other)
+
+ def rightOuterJoin(self, other, numPartitions=None):
+ """
+ Return a new DStream by applying 'right outer join' between RDDs of this DStream and
+ `other` DStream.
+
+ Hash partitioning is used to generate the RDDs with `numPartitions`
+ partitions.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+ return self.transformWith(lambda a, b: a.rightOuterJoin(b, numPartitions), other)
+
+ def fullOuterJoin(self, other, numPartitions=None):
+ """
+ Return a new DStream by applying 'full outer join' between RDDs of this DStream and
+ `other` DStream.
+
+ Hash partitioning is used to generate the RDDs with `numPartitions`
+ partitions.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+ return self.transformWith(lambda a, b: a.fullOuterJoin(b, numPartitions), other)
+
+ def _jtime(self, timestamp):
+ """ Convert datetime or unix_timestamp into Time
+ """
+ if isinstance(timestamp, datetime):
+ timestamp = time.mktime(timestamp.timetuple())
+ return self._sc._jvm.Time(long(timestamp * 1000))
+
+ def slice(self, begin, end):
+ """
+ Return all the RDDs between 'begin' to 'end' (both included)
+
+ `begin`, `end` could be datetime.datetime() or unix_timestamp
+ """
+ jrdds = self._jdstream.slice(self._jtime(begin), self._jtime(end))
+ return [RDD(jrdd, self._sc, self._jrdd_deserializer) for jrdd in jrdds]
+
+ def _validate_window_param(self, window, slide):
+ duration = self._jdstream.dstream().slideDuration().milliseconds()
+ if int(window * 1000) % duration != 0:
+            raise ValueError("windowDuration must be a multiple of the slide duration (%d ms)"
+ % duration)
+ if slide and int(slide * 1000) % duration != 0:
+            raise ValueError("slideDuration must be a multiple of the slide duration (%d ms)"
+ % duration)
+
+ def window(self, windowDuration, slideDuration=None):
+ """
+        Return a new DStream in which each RDD contains all the elements seen in a
+        sliding window of time over this DStream.
+
+ @param windowDuration: width of the window; must be a multiple of this DStream's
+ batching interval
+ @param slideDuration: sliding interval of the window (i.e., the interval after which
+ the new DStream will generate RDDs); must be a multiple of this
+ DStream's batching interval
+ """
+ self._validate_window_param(windowDuration, slideDuration)
+ d = self._ssc._jduration(windowDuration)
+ if slideDuration is None:
+ return DStream(self._jdstream.window(d), self._ssc, self._jrdd_deserializer)
+ s = self._ssc._jduration(slideDuration)
+ return DStream(self._jdstream.window(d, s), self._ssc, self._jrdd_deserializer)
+
+ def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuration):
+ """
+ Return a new DStream in which each RDD has a single element generated by reducing all
+ elements in a sliding window over this DStream.
+
+        If `invReduceFunc` is not None, the reduction is done incrementally
+        using the old window's reduced value:
+        1. reduce the new values that entered the window (e.g., adding new counts)
+        2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
+        This is more efficient than when `invReduceFunc` is None.
+
+ @param reduceFunc: associative reduce function
+ @param invReduceFunc: inverse reduce function of `reduceFunc`
+ @param windowDuration: width of the window; must be a multiple of this DStream's
+ batching interval
+ @param slideDuration: sliding interval of the window (i.e., the interval after which
+ the new DStream will generate RDDs); must be a multiple of this
+ DStream's batching interval
+ """
+ keyed = self.map(lambda x: (1, x))
+ reduced = keyed.reduceByKeyAndWindow(reduceFunc, invReduceFunc,
+ windowDuration, slideDuration, 1)
+ return reduced.map(lambda (k, v): v)
+
+ def countByWindow(self, windowDuration, slideDuration):
+ """
+ Return a new DStream in which each RDD has a single element generated
+ by counting the number of elements in a window over this DStream.
+ windowDuration and slideDuration are as defined in the window() operation.
+
+        This is equivalent to window(windowDuration, slideDuration).count(),
+        but will be more efficient if the window is large.
+ """
+ return self.map(lambda x: 1).reduceByWindow(operator.add, operator.sub,
+ windowDuration, slideDuration)
+
+ def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=None):
+ """
+ Return a new DStream in which each RDD contains the count of distinct elements in
+ RDDs in a sliding window over this DStream.
+
+ @param windowDuration: width of the window; must be a multiple of this DStream's
+ batching interval
+ @param slideDuration: sliding interval of the window (i.e., the interval after which
+ the new DStream will generate RDDs); must be a multiple of this
+ DStream's batching interval
+ @param numPartitions: number of partitions of each RDD in the new DStream.
+ """
+ keyed = self.map(lambda x: (x, 1))
+ counted = keyed.reduceByKeyAndWindow(operator.add, operator.sub,
+ windowDuration, slideDuration, numPartitions)
+ return counted.filter(lambda (k, v): v > 0).count()
+
+ def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None):
+ """
+ Return a new DStream by applying `groupByKey` over a sliding window.
+ Similar to `DStream.groupByKey()`, but applies it over a sliding window.
+
+ @param windowDuration: width of the window; must be a multiple of this DStream's
+ batching interval
+ @param slideDuration: sliding interval of the window (i.e., the interval after which
+ the new DStream will generate RDDs); must be a multiple of this
+ DStream's batching interval
+ @param numPartitions: Number of partitions of each RDD in the new DStream.
+ """
+ ls = self.mapValues(lambda x: [x])
+ grouped = ls.reduceByKeyAndWindow(lambda a, b: a.extend(b) or a, lambda a, b: a[len(b):],
+ windowDuration, slideDuration, numPartitions)
+ return grouped.mapValues(ResultIterable)
+
+ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None,
+ numPartitions=None, filterFunc=None):
+ """
+ Return a new DStream by applying incremental `reduceByKey` over a sliding window.
+
+        The reduced value for a new window is calculated using the old window's reduced value:
+ 1. reduce the new values that entered the window (e.g., adding new counts)
+ 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
+
+        `invFunc` can be None; in that case all the RDDs in the window are reduced from
+        scratch, which can be slower than providing `invFunc`.
+
+        @param func: associative reduce function
+        @param invFunc: inverse function of `func`
+ @param windowDuration: width of the window; must be a multiple of this DStream's
+ batching interval
+ @param slideDuration: sliding interval of the window (i.e., the interval after which
+ the new DStream will generate RDDs); must be a multiple of this
+ DStream's batching interval
+ @param numPartitions: number of partitions of each RDD in the new DStream.
+ @param filterFunc: function to filter expired key-value pairs;
+                           only pairs that satisfy the function are retained;
+                           set this to None if you do not want to filter
+ """
+ self._validate_window_param(windowDuration, slideDuration)
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+
+ reduced = self.reduceByKey(func, numPartitions)
+
+ def reduceFunc(t, a, b):
+ b = b.reduceByKey(func, numPartitions)
+ r = a.union(b).reduceByKey(func, numPartitions) if a else b
+ if filterFunc:
+ r = r.filter(filterFunc)
+ return r
+
+ def invReduceFunc(t, a, b):
+ b = b.reduceByKey(func, numPartitions)
+ joined = a.leftOuterJoin(b, numPartitions)
+ return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1)
+
+ jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer)
+ if invReduceFunc:
+ jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, reduced._jrdd_deserializer)
+ else:
+ jinvReduceFunc = None
+ if slideDuration is None:
+ slideDuration = self._slideDuration
+ dstream = self._sc._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(),
+ jreduceFunc, jinvReduceFunc,
+ self._ssc._jduration(windowDuration),
+ self._ssc._jduration(slideDuration))
+ return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer)
+
+ def updateStateByKey(self, updateFunc, numPartitions=None):
+ """
+ Return a new "state" DStream where the state for each key is updated by applying
+ the given function on the previous state of the key and the new values of the key.
+
+ @param updateFunc: State update function. If this function returns None, then
+ corresponding state key-value pair will be eliminated.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+
+ def reduceFunc(t, a, b):
+ if a is None:
+ g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None))
+ else:
+ g = a.cogroup(b, numPartitions)
+ g = g.mapValues(lambda (va, vb): (list(vb), list(va)[0] if len(va) else None))
+ state = g.mapValues(lambda (vs, s): updateFunc(vs, s))
+ return state.filter(lambda (k, v): v is not None)
+
+ jreduceFunc = TransformFunction(self._sc, reduceFunc,
+ self._sc.serializer, self._jrdd_deserializer)
+ dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc)
+ return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer)
+
+
+class TransformedDStream(DStream):
+ """
+    TransformedDStream is a DStream generated by a Python function that
+    transforms each RDD of a DStream into another RDD.
+
+    Multiple consecutive transformations of a DStream can be combined into
+    one transformation.
+ """
+ def __init__(self, prev, func):
+ self._ssc = prev._ssc
+ self._sc = self._ssc._sc
+ self._jrdd_deserializer = self._sc.serializer
+ self.is_cached = False
+ self.is_checkpointed = False
+ self._jdstream_val = None
+
+ if (isinstance(prev, TransformedDStream) and
+ not prev.is_cached and not prev.is_checkpointed):
+ prev_func = prev.func
+ self.func = lambda t, rdd: func(t, prev_func(t, rdd))
+ self.prev = prev.prev
+ else:
+ self.prev = prev
+ self.func = func
+
+ @property
+ def _jdstream(self):
+ if self._jdstream_val is not None:
+ return self._jdstream_val
+
+ jfunc = TransformFunction(self._sc, self.func, self.prev._jrdd_deserializer)
+ dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc)
+ self._jdstream_val = dstream.asJavaDStream()
+ return self._jdstream_val
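As one example of the windowing operations above, a sketch of incrementally updated word counts over a sliding window. It assumes the same hypothetical localhost:9999 text source; the checkpoint directory is a placeholder, and checkpointing is generally needed when the inverse-reduce path is used:

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    sc = SparkContext("local[2]", "WindowedWordCount")
    ssc = StreamingContext(sc, 1)
    ssc.checkpoint("/tmp/windowed-wordcount-checkpoint")

    pairs = ssc.socketTextStream("localhost", 9999) \
               .flatMap(lambda line: line.split(" ")) \
               .map(lambda word: (word, 1))

    # 30-second windows recomputed every 10 seconds; the inverse function subtracts
    # counts for data that slides out of the window instead of re-reducing everything.
    windowed = pairs.reduceByKeyAndWindow(lambda a, b: a + b,
                                          lambda a, b: a - b,
                                          windowDuration=30, slideDuration=10)
    windowed.pprint()

    ssc.start()
    ssc.awaitTermination()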
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
new file mode 100644
index 0000000000000..a8d876d0fa3b3
--- /dev/null
+++ b/python/pyspark/streaming/tests.py
@@ -0,0 +1,545 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+from itertools import chain
+import time
+import operator
+import unittest
+import tempfile
+
+from pyspark.context import SparkConf, SparkContext, RDD
+from pyspark.streaming.context import StreamingContext
+
+
+class PySparkStreamingTestCase(unittest.TestCase):
+
+ timeout = 10 # seconds
+ duration = 1
+
+ def setUp(self):
+ class_name = self.__class__.__name__
+ conf = SparkConf().set("spark.default.parallelism", 1)
+ self.sc = SparkContext(appName=class_name, conf=conf)
+ self.sc.setCheckpointDir("/tmp")
+ # TODO: decrease duration to speed up tests
+ self.ssc = StreamingContext(self.sc, self.duration)
+
+ def tearDown(self):
+ self.ssc.stop()
+
+ def wait_for(self, result, n):
+ start_time = time.time()
+ while len(result) < n and time.time() - start_time < self.timeout:
+ time.sleep(0.01)
+ if len(result) < n:
+ print "timeout after", self.timeout
+
+ def _take(self, dstream, n):
+ """
+ Return the first `n` elements in the stream (will start and stop).
+ """
+ results = []
+
+ def take(_, rdd):
+ if rdd and len(results) < n:
+ results.extend(rdd.take(n - len(results)))
+
+ dstream.foreachRDD(take)
+
+ self.ssc.start()
+ self.wait_for(results, n)
+ return results
+
+ def _collect(self, dstream, n, block=True):
+ """
+        Collect each RDD into the returned list.
+
+ :return: list, which will have the collected items.
+ """
+ result = []
+
+ def get_output(_, rdd):
+ if rdd and len(result) < n:
+ r = rdd.collect()
+ if r:
+ result.append(r)
+
+ dstream.foreachRDD(get_output)
+
+ if not block:
+ return result
+
+ self.ssc.start()
+ self.wait_for(result, n)
+ return result
+
+ def _test_func(self, input, func, expected, sort=False, input2=None):
+ """
+        @param input: dataset for the test. This should be a list of lists.
+        @param func: wrapped function. This function should return a DStream object.
+ @param expected: expected output for this testcase.
+ """
+ if not isinstance(input[0], RDD):
+ input = [self.sc.parallelize(d, 1) for d in input]
+ input_stream = self.ssc.queueStream(input)
+ if input2 and not isinstance(input2[0], RDD):
+ input2 = [self.sc.parallelize(d, 1) for d in input2]
+ input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None
+
+ # Apply test function to stream.
+ if input2:
+ stream = func(input_stream, input_stream2)
+ else:
+ stream = func(input_stream)
+
+ result = self._collect(stream, len(expected))
+ if sort:
+ self._sort_result_based_on_key(result)
+ self._sort_result_based_on_key(expected)
+ self.assertEqual(expected, result)
+
+ def _sort_result_based_on_key(self, outputs):
+ """Sort the list based on first value."""
+ for output in outputs:
+ output.sort(key=lambda x: x[0])
+
+
+class BasicOperationTests(PySparkStreamingTestCase):
+
+ def test_map(self):
+ """Basic operation test for DStream.map."""
+ input = [range(1, 5), range(5, 9), range(9, 13)]
+
+ def func(dstream):
+ return dstream.map(str)
+ expected = map(lambda x: map(str, x), input)
+ self._test_func(input, func, expected)
+
+ def test_flatMap(self):
+        """Basic operation test for DStream.flatMap."""
+ input = [range(1, 5), range(5, 9), range(9, 13)]
+
+ def func(dstream):
+ return dstream.flatMap(lambda x: (x, x * 2))
+ expected = map(lambda x: list(chain.from_iterable((map(lambda y: [y, y * 2], x)))),
+ input)
+ self._test_func(input, func, expected)
+
+ def test_filter(self):
+ """Basic operation test for DStream.filter."""
+ input = [range(1, 5), range(5, 9), range(9, 13)]
+
+ def func(dstream):
+ return dstream.filter(lambda x: x % 2 == 0)
+ expected = map(lambda x: filter(lambda y: y % 2 == 0, x), input)
+ self._test_func(input, func, expected)
+
+ def test_count(self):
+ """Basic operation test for DStream.count."""
+ input = [range(5), range(10), range(20)]
+
+ def func(dstream):
+ return dstream.count()
+ expected = map(lambda x: [len(x)], input)
+ self._test_func(input, func, expected)
+
+ def test_reduce(self):
+ """Basic operation test for DStream.reduce."""
+ input = [range(1, 5), range(5, 9), range(9, 13)]
+
+ def func(dstream):
+ return dstream.reduce(operator.add)
+ expected = map(lambda x: [reduce(operator.add, x)], input)
+ self._test_func(input, func, expected)
+
+ def test_reduceByKey(self):
+ """Basic operation test for DStream.reduceByKey."""
+ input = [[("a", 1), ("a", 1), ("b", 1), ("b", 1)],
+ [("", 1), ("", 1), ("", 1), ("", 1)],
+ [(1, 1), (1, 1), (2, 1), (2, 1), (3, 1)]]
+
+ def func(dstream):
+ return dstream.reduceByKey(operator.add)
+ expected = [[("a", 2), ("b", 2)], [("", 4)], [(1, 2), (2, 2), (3, 1)]]
+ self._test_func(input, func, expected, sort=True)
+
+ def test_mapValues(self):
+ """Basic operation test for DStream.mapValues."""
+ input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)],
+ [("", 4), (1, 1), (2, 2), (3, 3)],
+ [(1, 1), (2, 1), (3, 1), (4, 1)]]
+
+ def func(dstream):
+ return dstream.mapValues(lambda x: x + 10)
+ expected = [[("a", 12), ("b", 12), ("c", 11), ("d", 11)],
+ [("", 14), (1, 11), (2, 12), (3, 13)],
+ [(1, 11), (2, 11), (3, 11), (4, 11)]]
+ self._test_func(input, func, expected, sort=True)
+
+ def test_flatMapValues(self):
+ """Basic operation test for DStream.flatMapValues."""
+ input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)],
+ [("", 4), (1, 1), (2, 1), (3, 1)],
+ [(1, 1), (2, 1), (3, 1), (4, 1)]]
+
+ def func(dstream):
+ return dstream.flatMapValues(lambda x: (x, x + 10))
+ expected = [[("a", 2), ("a", 12), ("b", 2), ("b", 12),
+ ("c", 1), ("c", 11), ("d", 1), ("d", 11)],
+ [("", 4), ("", 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)],
+ [(1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11), (4, 1), (4, 11)]]
+ self._test_func(input, func, expected)
+
+ def test_glom(self):
+ """Basic operation test for DStream.glom."""
+ input = [range(1, 5), range(5, 9), range(9, 13)]
+ rdds = [self.sc.parallelize(r, 2) for r in input]
+
+ def func(dstream):
+ return dstream.glom()
+ expected = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]
+ self._test_func(rdds, func, expected)
+
+ def test_mapPartitions(self):
+ """Basic operation test for DStream.mapPartitions."""
+ input = [range(1, 5), range(5, 9), range(9, 13)]
+ rdds = [self.sc.parallelize(r, 2) for r in input]
+
+ def func(dstream):
+ def f(iterator):
+ yield sum(iterator)
+ return dstream.mapPartitions(f)
+ expected = [[3, 7], [11, 15], [19, 23]]
+ self._test_func(rdds, func, expected)
+
+ def test_countByValue(self):
+ """Basic operation test for DStream.countByValue."""
+ input = [range(1, 5) * 2, range(5, 7) + range(5, 9), ["a", "a", "b", ""]]
+
+ def func(dstream):
+ return dstream.countByValue()
+ expected = [[4], [4], [3]]
+ self._test_func(input, func, expected)
+
+ def test_groupByKey(self):
+ """Basic operation test for DStream.groupByKey."""
+ input = [[(1, 1), (2, 1), (3, 1), (4, 1)],
+ [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)],
+ [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]]
+
+ def func(dstream):
+ return dstream.groupByKey().mapValues(list)
+
+ expected = [[(1, [1]), (2, [1]), (3, [1]), (4, [1])],
+ [(1, [1, 1, 1]), (2, [1, 1]), (3, [1])],
+ [("a", [1, 1]), ("b", [1]), ("", [1, 1, 1])]]
+ self._test_func(input, func, expected, sort=True)
+
+ def test_combineByKey(self):
+ """Basic operation test for DStream.combineByKey."""
+ input = [[(1, 1), (2, 1), (3, 1), (4, 1)],
+ [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)],
+ [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]]
+
+ def func(dstream):
+ def add(a, b):
+ return a + str(b)
+ return dstream.combineByKey(str, add, add)
+ expected = [[(1, "1"), (2, "1"), (3, "1"), (4, "1")],
+ [(1, "111"), (2, "11"), (3, "1")],
+ [("a", "11"), ("b", "1"), ("", "111")]]
+ self._test_func(input, func, expected, sort=True)
+
+ def test_repartition(self):
+ input = [range(1, 5), range(5, 9)]
+ rdds = [self.sc.parallelize(r, 2) for r in input]
+
+ def func(dstream):
+ return dstream.repartition(1).glom()
+ expected = [[[1, 2, 3, 4]], [[5, 6, 7, 8]]]
+ self._test_func(rdds, func, expected)
+
+ def test_union(self):
+ input1 = [range(3), range(5), range(6)]
+ input2 = [range(3, 6), range(5, 6)]
+
+ def func(d1, d2):
+ return d1.union(d2)
+
+ expected = [range(6), range(6), range(6)]
+ self._test_func(input1, func, expected, input2=input2)
+
+ def test_cogroup(self):
+ input = [[(1, 1), (2, 1), (3, 1)],
+ [(1, 1), (1, 1), (1, 1), (2, 1)],
+ [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1)]]
+ input2 = [[(1, 2)],
+ [(4, 1)],
+ [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 2)]]
+
+ def func(d1, d2):
+ return d1.cogroup(d2).mapValues(lambda vs: tuple(map(list, vs)))
+
+ expected = [[(1, ([1], [2])), (2, ([1], [])), (3, ([1], []))],
+ [(1, ([1, 1, 1], [])), (2, ([1], [])), (4, ([], [1]))],
+ [("a", ([1, 1], [1, 1])), ("b", ([1], [1])), ("", ([1, 1], [1, 2]))]]
+ self._test_func(input, func, expected, sort=True, input2=input2)
+
+ def test_join(self):
+ input = [[('a', 1), ('b', 2)]]
+ input2 = [[('b', 3), ('c', 4)]]
+
+ def func(a, b):
+ return a.join(b)
+
+ expected = [[('b', (2, 3))]]
+ self._test_func(input, func, expected, True, input2)
+
+ def test_left_outer_join(self):
+ input = [[('a', 1), ('b', 2)]]
+ input2 = [[('b', 3), ('c', 4)]]
+
+ def func(a, b):
+ return a.leftOuterJoin(b)
+
+ expected = [[('a', (1, None)), ('b', (2, 3))]]
+ self._test_func(input, func, expected, True, input2)
+
+ def test_right_outer_join(self):
+ input = [[('a', 1), ('b', 2)]]
+ input2 = [[('b', 3), ('c', 4)]]
+
+ def func(a, b):
+ return a.rightOuterJoin(b)
+
+ expected = [[('b', (2, 3)), ('c', (None, 4))]]
+ self._test_func(input, func, expected, True, input2)
+
+ def test_full_outer_join(self):
+ input = [[('a', 1), ('b', 2)]]
+ input2 = [[('b', 3), ('c', 4)]]
+
+ def func(a, b):
+ return a.fullOuterJoin(b)
+
+ expected = [[('a', (1, None)), ('b', (2, 3)), ('c', (None, 4))]]
+ self._test_func(input, func, expected, True, input2)
+
+ def test_update_state_by_key(self):
+
+ def updater(vs, s):
+ if not s:
+ s = []
+ s.extend(vs)
+ return s
+
+ input = [[('k', i)] for i in range(5)]
+
+ def func(dstream):
+ return dstream.updateStateByKey(updater)
+
+ expected = [[0], [0, 1], [0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]
+ expected = [[('k', v)] for v in expected]
+ self._test_func(input, func, expected)
+
+
+class WindowFunctionTests(PySparkStreamingTestCase):
+
+ timeout = 20
+
+ def test_window(self):
+ input = [range(1), range(2), range(3), range(4), range(5)]
+
+ def func(dstream):
+ return dstream.window(3, 1).count()
+
+ expected = [[1], [3], [6], [9], [12], [9], [5]]
+ self._test_func(input, func, expected)
+
+ def test_count_by_window(self):
+ input = [range(1), range(2), range(3), range(4), range(5)]
+
+ def func(dstream):
+ return dstream.countByWindow(3, 1)
+
+ expected = [[1], [3], [6], [9], [12], [9], [5]]
+ self._test_func(input, func, expected)
+
+ def test_count_by_window_large(self):
+ input = [range(1), range(2), range(3), range(4), range(5), range(6)]
+
+ def func(dstream):
+ return dstream.countByWindow(5, 1)
+
+ expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]]
+ self._test_func(input, func, expected)
+
+ def test_count_by_value_and_window(self):
+ input = [range(1), range(2), range(3), range(4), range(5), range(6)]
+
+ def func(dstream):
+ return dstream.countByValueAndWindow(5, 1)
+
+ expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]]
+ self._test_func(input, func, expected)
+
+ def test_group_by_key_and_window(self):
+ input = [[('a', i)] for i in range(5)]
+
+ def func(dstream):
+ return dstream.groupByKeyAndWindow(3, 1).mapValues(list)
+
+ expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])],
+ [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]]
+ self._test_func(input, func, expected)
+
+ def test_reduce_by_invalid_window(self):
+ input1 = [range(3), range(5), range(1), range(6)]
+ d1 = self.ssc.queueStream(input1)
+ self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1))
+ self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1))
+
+
+class StreamingContextTests(PySparkStreamingTestCase):
+
+ duration = 0.1
+
+ def _add_input_stream(self):
+ inputs = map(lambda x: range(1, x), range(101))
+ stream = self.ssc.queueStream(inputs)
+ self._collect(stream, 1, block=False)
+
+ def test_stop_only_streaming_context(self):
+ self._add_input_stream()
+ self.ssc.start()
+ self.ssc.stop(False)
+ self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5)
+
+ def test_stop_multiple_times(self):
+ self._add_input_stream()
+ self.ssc.start()
+ self.ssc.stop()
+ self.ssc.stop()
+
+ def test_queue_stream(self):
+ input = [range(i + 1) for i in range(3)]
+ dstream = self.ssc.queueStream(input)
+ result = self._collect(dstream, 3)
+ self.assertEqual(input, result)
+
+ def test_text_file_stream(self):
+ d = tempfile.mkdtemp()
+ self.ssc = StreamingContext(self.sc, self.duration)
+ dstream2 = self.ssc.textFileStream(d).map(int)
+ result = self._collect(dstream2, 2, block=False)
+ self.ssc.start()
+ for name in ('a', 'b'):
+ time.sleep(1)
+ with open(os.path.join(d, name), "w") as f:
+ f.writelines(["%d\n" % i for i in range(10)])
+ self.wait_for(result, 2)
+ self.assertEqual([range(10), range(10)], result)
+
+ def test_union(self):
+ input = [range(i + 1) for i in range(3)]
+ dstream = self.ssc.queueStream(input)
+ dstream2 = self.ssc.queueStream(input)
+ dstream3 = self.ssc.union(dstream, dstream2)
+ result = self._collect(dstream3, 3)
+ expected = [i * 2 for i in input]
+ self.assertEqual(expected, result)
+
+ def test_transform(self):
+ dstream1 = self.ssc.queueStream([[1]])
+ dstream2 = self.ssc.queueStream([[2]])
+ dstream3 = self.ssc.queueStream([[3]])
+
+ def func(rdds):
+ rdd1, rdd2, rdd3 = rdds
+ return rdd2.union(rdd3).union(rdd1)
+
+ dstream = self.ssc.transform([dstream1, dstream2, dstream3], func)
+
+ self.assertEqual([2, 3, 1], self._take(dstream, 3))
+
+
+class CheckpointTests(PySparkStreamingTestCase):
+
+ def setUp(self):
+ pass
+
+ def test_get_or_create(self):
+ inputd = tempfile.mkdtemp()
+ outputd = tempfile.mkdtemp() + "/"
+
+ def updater(vs, s):
+ return sum(vs, s or 0)
+
+ def setup():
+ conf = SparkConf().set("spark.default.parallelism", 1)
+ sc = SparkContext(conf=conf)
+ ssc = StreamingContext(sc, 0.5)
+ dstream = ssc.textFileStream(inputd).map(lambda x: (x, 1))
+ wc = dstream.updateStateByKey(updater)
+ wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test")
+ wc.checkpoint(.5)
+ return ssc
+
+ cpd = tempfile.mkdtemp("test_streaming_cps")
+ self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup)
+ ssc.start()
+
+ def check_output(n):
+ while not os.listdir(outputd):
+ time.sleep(0.1)
+ time.sleep(1) # make sure mtime is larger than the previous one
+ with open(os.path.join(inputd, str(n)), 'w') as f:
+ f.writelines(["%d\n" % i for i in range(10)])
+
+ while True:
+ p = os.path.join(outputd, max(os.listdir(outputd)))
+ if '_SUCCESS' not in os.listdir(p):
+ # not finished
+ time.sleep(0.01)
+ continue
+ ordd = ssc.sparkContext.textFile(p).map(lambda line: line.split(","))
+ d = ordd.values().map(int).collect()
+ if not d:
+ time.sleep(0.01)
+ continue
+ self.assertEqual(10, len(d))
+ s = set(d)
+ self.assertEqual(1, len(s))
+ m = s.pop()
+ if n > m:
+ continue
+ self.assertEqual(n, m)
+ break
+
+ check_output(1)
+ check_output(2)
+ ssc.stop(True, True)
+
+ time.sleep(1)
+ self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup)
+ ssc.start()
+ check_output(3)
+
+
+if __name__ == "__main__":
+ unittest.main()
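Outside the test harness, the same queueStream technique the tests rely on can be used to drive a stream from pre-built data; a standalone sketch (the sleep interval is arbitrary, just long enough for a few batches to run):

    import time

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    sc = SparkContext("local[2]", "queue-stream-demo")
    ssc = StreamingContext(sc, 1)

    # Each list becomes one RDD and is consumed in its own batch (oneAtATime=True).
    stream = ssc.queueStream([[1, 2, 3], [4, 5], [6]])
    stream.map(lambda x: x * 10).pprint()

    ssc.start()
    time.sleep(5)
    ssc.stop(stopSparkContext=True, stopGraceFully=False)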
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
new file mode 100644
index 0000000000000..86ee5aa04f252
--- /dev/null
+++ b/python/pyspark/streaming/util.py
@@ -0,0 +1,128 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+from datetime import datetime
+import traceback
+
+from pyspark import SparkContext, RDD
+
+
+class TransformFunction(object):
+ """
+ This class wraps a function RDD[X] -> RDD[Y] that was passed to
+ DStream.transform(), allowing it to be called from Java via Py4J's
+ callback server.
+
+ Java calls this function with a sequence of JavaRDDs, and it returns a single
+ JavaRDD pointer back to Java.
+ """
+ _emptyRDD = None
+
+ def __init__(self, ctx, func, *deserializers):
+ self.ctx = ctx
+ self.func = func
+ self.deserializers = deserializers
+
+ def call(self, milliseconds, jrdds):
+ try:
+ if self.ctx is None:
+ self.ctx = SparkContext._active_spark_context
+ if not self.ctx or not self.ctx._jsc:
+ # stopped
+ return
+
+ # pad the deserializers with the first one so there is one per JavaRDD
+ sers = self.deserializers
+ if len(sers) < len(jrdds):
+ sers += (sers[0],) * (len(jrdds) - len(sers))
+
+ rdds = [RDD(jrdd, self.ctx, ser) if jrdd else None
+ for jrdd, ser in zip(jrdds, sers)]
+ t = datetime.fromtimestamp(milliseconds / 1000.0)
+ r = self.func(t, *rdds)
+ if r:
+ return r._jrdd
+ except Exception:
+ traceback.print_exc()
+
+ def __repr__(self):
+ return "TransformFunction(%s)" % self.func
+
+ class Java:
+ implements = ['org.apache.spark.streaming.api.python.PythonTransformFunction']
+
+
+class TransformFunctionSerializer(object):
+ """
+ This class implements a serializer for PythonTransformFunction Java
+ objects.
+
+ This is necessary because the Java PythonTransformFunction objects are
+ actually Py4J references to Python objects and thus are not directly
+ serializable. When Java needs to serialize a PythonTransformFunction,
+ it uses this class to invoke Python, which returns the serialized function
+ as a byte array.
+ """
+ def __init__(self, ctx, serializer, gateway=None):
+ self.ctx = ctx
+ self.serializer = serializer
+ self.gateway = gateway or self.ctx._gateway
+ self.gateway.jvm.PythonDStream.registerSerializer(self)
+
+ def dumps(self, id):
+ try:
+ func = self.gateway.gateway_property.pool[id]
+ return bytearray(self.serializer.dumps((func.func, func.deserializers)))
+ except Exception:
+ traceback.print_exc()
+
+ def loads(self, bytes):
+ try:
+ f, deserializers = self.serializer.loads(str(bytes))
+ return TransformFunction(self.ctx, f, *deserializers)
+ except Exception:
+ traceback.print_exc()
+
+ def __repr__(self):
+ return "TransformFunctionSerializer(%s)" % self.serializer
+
+ class Java:
+ implements = ['org.apache.spark.streaming.api.python.PythonTransformFunctionSerializer']
+
+
+def rddToFileName(prefix, suffix, timestamp):
+ """
+ Return a string of the form prefix-timestamp[.suffix]
+
+ >>> rddToFileName("spark", None, 12345678910)
+ 'spark-12345678910'
+ >>> rddToFileName("spark", "tmp", 12345678910)
+ 'spark-12345678910.tmp'
+ """
+ if isinstance(timestamp, datetime):
+ seconds = time.mktime(timestamp.timetuple())
+ timestamp = long(seconds * 1000) + timestamp.microsecond / 1000
+ if suffix is None:
+ return prefix + "-" + str(timestamp)
+ else:
+ return prefix + "-" + str(timestamp) + "." + suffix
+
+
+if __name__ == "__main__":
+ import doctest
+ doctest.testmod()
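
rddToFileName mirrors the directory-naming convention that DStream.saveAsTextFiles and friends use on the Scala side: prefix-<time in milliseconds>, with an optional .suffix. A Scala sketch of an equivalent helper (the object name is made up; the expected outputs match the doctests above):

object FileNamingSketch {
  // Hypothetical Scala counterpart of the Python rddToFileName helper.
  def rddToFileName(prefix: String, suffix: Option[String], timestampMs: Long): String =
    suffix match {
      case None    => s"$prefix-$timestampMs"
      case Some(s) => s"$prefix-$timestampMs.$s"
    }

  def main(args: Array[String]): Unit = {
    println(rddToFileName("spark", None, 12345678910L))        // spark-12345678910
    println(rddToFileName("spark", Some("tmp"), 12345678910L)) // spark-12345678910.tmp
  }
}
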
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 7f05d48ade2b3..ceab57464f013 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -34,7 +34,11 @@
from platform import python_implementation
if sys.version_info[:2] <= (2, 6):
- import unittest2 as unittest
+ try:
+ import unittest2 as unittest
+ except ImportError:
+ sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier\n')
+ sys.exit(1)
else:
import unittest
diff --git a/python/run-tests b/python/run-tests
index 63395f72788f9..80acd002ab7eb 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -25,16 +25,17 @@ FWDIR="$(cd "`dirname "$0"`"; cd ../; pwd)"
cd "$FWDIR/python"
FAILED=0
+LOG_FILE=unit-tests.log
-rm -f unit-tests.log
+rm -f $LOG_FILE
# Remove the metastore and warehouse directory created by the HiveContext tests in Spark SQL
rm -rf metastore warehouse
function run_test() {
- echo "Running test: $1"
+ echo "Running test: $1" | tee -a $LOG_FILE
- SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a unit-tests.log
+ SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a $LOG_FILE
FAILED=$((PIPESTATUS[0]||$FAILED))
@@ -80,6 +81,12 @@ function run_mllib_tests() {
run_test "pyspark/mllib/tests.py"
}
+function run_streaming_tests() {
+ echo "Run streaming tests ..."
+ run_test "pyspark/streaming/util.py"
+ run_test "pyspark/streaming/tests.py"
+}
+
echo "Running PySpark tests. Output is in python/unit-tests.log."
export PYSPARK_PYTHON="python"
@@ -95,6 +102,7 @@ $PYSPARK_PYTHON --version
run_core_tests
run_sql_tests
run_mllib_tests
+run_streaming_tests
# Try to test with PyPy
if [ $(which pypy) ]; then
@@ -104,6 +112,7 @@ if [ $(which pypy) ]; then
run_core_tests
run_sql_tests
+ run_streaming_tests
fi
if [[ $FAILED == 0 ]]; then
diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
index 3e2ee7541f40d..6a79e76a34db8 100644
--- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
+++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
@@ -23,8 +23,6 @@ import java.net.{URL, URLClassLoader}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
-import com.google.common.io.Files
-
import org.apache.spark.{SparkConf, TestUtils}
import org.apache.spark.util.Utils
@@ -39,10 +37,8 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll {
override def beforeAll() {
super.beforeAll()
- tempDir1 = Files.createTempDir()
- tempDir1.deleteOnExit()
- tempDir2 = Files.createTempDir()
- tempDir2.deleteOnExit()
+ tempDir1 = Utils.createTempDir()
+ tempDir2 = Utils.createTempDir()
url1 = "file://" + tempDir1
urls2 = List(tempDir2.toURI.toURL).toArray
childClassNames.foreach(TestUtils.createCompiledClass(_, tempDir1, "1"))
diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index c8763eb277052..91c9c52c3c98a 100644
--- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -22,7 +22,6 @@ import java.net.URLClassLoader
import scala.collection.mutable.ArrayBuffer
-import com.google.common.io.Files
import org.scalatest.FunSuite
import org.apache.spark.SparkContext
import org.apache.commons.lang3.StringEscapeUtils
@@ -190,8 +189,7 @@ class ReplSuite extends FunSuite {
}
test("interacting with files") {
- val tempDir = Files.createTempDir()
- tempDir.deleteOnExit()
+ val tempDir = Utils.createTempDir()
val out = new FileWriter(tempDir + "/input")
out.write("Hello world!\n")
out.write("What's up?\n")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala
new file mode 100644
index 0000000000000..04467342e6ab5
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SparkSQLParser.scala
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst
+
+import scala.language.implicitConversions
+import scala.util.parsing.combinator.lexical.StdLexical
+import scala.util.parsing.combinator.syntactical.StandardTokenParsers
+import scala.util.parsing.combinator.{PackratParsers, RegexParsers}
+import scala.util.parsing.input.CharArrayReader.EofCh
+
+import org.apache.spark.sql.catalyst.plans.logical._
+
+private[sql] abstract class AbstractSparkSQLParser
+ extends StandardTokenParsers with PackratParsers {
+
+ def apply(input: String): LogicalPlan = phrase(start)(new lexical.Scanner(input)) match {
+ case Success(plan, _) => plan
+ case failureOrError => sys.error(failureOrError.toString)
+ }
+
+ protected case class Keyword(str: String)
+
+ protected def start: Parser[LogicalPlan]
+
+ // Returns the whole input string
+ protected lazy val wholeInput: Parser[String] = new Parser[String] {
+ def apply(in: Input): ParseResult[String] =
+ Success(in.source.toString, in.drop(in.source.length()))
+ }
+
+ // Returns the rest of the input string that is not parsed yet
+ protected lazy val restInput: Parser[String] = new Parser[String] {
+ def apply(in: Input): ParseResult[String] =
+ Success(
+ in.source.subSequence(in.offset, in.source.length()).toString,
+ in.drop(in.source.length()))
+ }
+}
+
+class SqlLexical(val keywords: Seq[String]) extends StdLexical {
+ case class FloatLit(chars: String) extends Token {
+ override def toString = chars
+ }
+
+ reserved ++= keywords.flatMap(w => allCaseVersions(w))
+
+ delimiters += (
+ "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
+ ",", ";", "%", "{", "}", ":", "[", "]", "."
+ )
+
+ override lazy val token: Parser[Token] =
+ ( identChar ~ (identChar | digit).* ^^
+ { case first ~ rest => processIdent((first :: rest).mkString) }
+ | rep1(digit) ~ ('.' ~> digit.*).? ^^ {
+ case i ~ None => NumericLit(i.mkString)
+ case i ~ Some(d) => FloatLit(i.mkString + "." + d.mkString)
+ }
+ | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^
+ { case chars => StringLit(chars mkString "") }
+ | '"' ~> chrExcept('"', '\n', EofCh).* <~ '"' ^^
+ { case chars => StringLit(chars mkString "") }
+ | EofCh ^^^ EOF
+ | '\'' ~> failure("unclosed string literal")
+ | '"' ~> failure("unclosed string literal")
+ | delim
+ | failure("illegal character")
+ )
+
+ override def identChar = letter | elem('_')
+
+ override def whitespace: Parser[Any] =
+ ( whitespaceChar
+ | '/' ~ '*' ~ comment
+ | '/' ~ '/' ~ chrExcept(EofCh, '\n').*
+ | '#' ~ chrExcept(EofCh, '\n').*
+ | '-' ~ '-' ~ chrExcept(EofCh, '\n').*
+ | '/' ~ '*' ~ failure("unclosed comment")
+ ).*
+
+ /** Generate all variations of upper and lower case of a given string */
+ def allCaseVersions(s: String, prefix: String = ""): Stream[String] = {
+ if (s == "") {
+ Stream(prefix)
+ } else {
+ allCaseVersions(s.tail, prefix + s.head.toLower) ++
+ allCaseVersions(s.tail, prefix + s.head.toUpper)
+ }
+ }
+}
+
+/**
+ * The top-level Spark SQL parser. It recognizes syntax that is available in all SQL dialects
+ * supported by Spark SQL, and delegates everything else to the `fallback` parser.
+ *
+ * @param fallback A function that parses an input string to a logical plan
+ */
+private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser {
+
+ // A parser for the key-value part of the "SET [key [= value]]" syntax
+ private object SetCommandParser extends RegexParsers {
+ private val key: Parser[String] = "(?m)[^=]+".r
+
+ private val value: Parser[String] = "(?m).*$".r
+
+ private val pair: Parser[LogicalPlan] =
+ (key ~ ("=".r ~> value).?).? ^^ {
+ case None => SetCommand(None)
+ case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim)))
+ }
+
+ def apply(input: String): LogicalPlan = parseAll(pair, input) match {
+ case Success(plan, _) => plan
+ case x => sys.error(x.toString)
+ }
+ }
+
+ protected val AS = Keyword("AS")
+ protected val CACHE = Keyword("CACHE")
+ protected val LAZY = Keyword("LAZY")
+ protected val SET = Keyword("SET")
+ protected val TABLE = Keyword("TABLE")
+ protected val SOURCE = Keyword("SOURCE")
+ protected val UNCACHE = Keyword("UNCACHE")
+
+ protected implicit def asParser(k: Keyword): Parser[String] =
+ lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
+
+ private val reservedWords: Seq[String] =
+ this
+ .getClass
+ .getMethods
+ .filter(_.getReturnType == classOf[Keyword])
+ .map(_.invoke(this).asInstanceOf[Keyword].str)
+
+ override val lexical = new SqlLexical(reservedWords)
+
+ override protected lazy val start: Parser[LogicalPlan] =
+ cache | uncache | set | shell | source | others
+
+ private lazy val cache: Parser[LogicalPlan] =
+ CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ {
+ case isLazy ~ tableName ~ plan =>
+ CacheTableCommand(tableName, plan.map(fallback), isLazy.isDefined)
+ }
+
+ private lazy val uncache: Parser[LogicalPlan] =
+ UNCACHE ~ TABLE ~> ident ^^ {
+ case tableName => UncacheTableCommand(tableName)
+ }
+
+ private lazy val set: Parser[LogicalPlan] =
+ SET ~> restInput ^^ {
+ case input => SetCommandParser(input)
+ }
+
+ private lazy val shell: Parser[LogicalPlan] =
+ "!" ~> restInput ^^ {
+ case input => ShellCommand(input.trim)
+ }
+
+ private lazy val source: Parser[LogicalPlan] =
+ SOURCE ~> restInput ^^ {
+ case input => SourceCommand(input.trim)
+ }
+
+ private lazy val others: Parser[LogicalPlan] =
+ wholeInput ^^ {
+ case input => fallback(input)
+ }
+}
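
The keyword handling above is what makes the new commands case-insensitive: SqlLexical reserves every upper/lower-case variation of each keyword, and asParser ORs those variations together into one parser. A standalone Scala sketch of the same allCaseVersions recursion, showing what actually gets registered (the object name is illustrative):

object CaseVersionsSketch {
  def allCaseVersions(s: String, prefix: String = ""): Stream[String] =
    if (s.isEmpty) Stream(prefix)
    else allCaseVersions(s.tail, prefix + s.head.toLower) ++
         allCaseVersions(s.tail, prefix + s.head.toUpper)

  def main(args: Array[String]): Unit = {
    // Every casing of a keyword becomes a reserved word, so "cache", "Cache",
    // "CACHE", ... all tokenize to the same keyword.
    println(allCaseVersions("as").toList)  // List(as, aS, As, AS)
    println(allCaseVersions("set").size)   // 8, i.e. 2^3 variations
  }
}
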
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 854b5b461bdc8..b4d606d37e732 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -18,10 +18,6 @@
package org.apache.spark.sql.catalyst
import scala.language.implicitConversions
-import scala.util.parsing.combinator.lexical.StdLexical
-import scala.util.parsing.combinator.syntactical.StandardTokenParsers
-import scala.util.parsing.combinator.PackratParsers
-import scala.util.parsing.input.CharArrayReader.EofCh
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
@@ -39,31 +35,7 @@ import org.apache.spark.sql.catalyst.types._
* This is currently included mostly for illustrative purposes. Users wanting more complete support
* for a SQL like language should checkout the HiveQL support in the sql/hive sub-project.
*/
-class SqlParser extends StandardTokenParsers with PackratParsers {
-
- def apply(input: String): LogicalPlan = {
- // Special-case out set commands since the value fields can be
- // complex to handle without RegexParsers. Also this approach
- // is clearer for the several possible cases of set commands.
- if (input.trim.toLowerCase.startsWith("set")) {
- input.trim.drop(3).split("=", 2).map(_.trim) match {
- case Array("") => // "set"
- SetCommand(None, None)
- case Array(key) => // "set key"
- SetCommand(Some(key), None)
- case Array(key, value) => // "set key=value"
- SetCommand(Some(key), Some(value))
- }
- } else {
- phrase(query)(new lexical.Scanner(input)) match {
- case Success(r, x) => r
- case x => sys.error(x.toString)
- }
- }
- }
-
- protected case class Keyword(str: String)
-
+class SqlParser extends AbstractSparkSQLParser {
protected implicit def asParser(k: Keyword): Parser[String] =
lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
@@ -77,10 +49,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val BETWEEN = Keyword("BETWEEN")
protected val BY = Keyword("BY")
protected val CACHE = Keyword("CACHE")
+ protected val CASE = Keyword("CASE")
protected val CAST = Keyword("CAST")
protected val COUNT = Keyword("COUNT")
protected val DESC = Keyword("DESC")
protected val DISTINCT = Keyword("DISTINCT")
+ protected val ELSE = Keyword("ELSE")
+ protected val END = Keyword("END")
protected val EXCEPT = Keyword("EXCEPT")
protected val FALSE = Keyword("FALSE")
protected val FIRST = Keyword("FIRST")
@@ -97,7 +72,6 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val IS = Keyword("IS")
protected val JOIN = Keyword("JOIN")
protected val LAST = Keyword("LAST")
- protected val LAZY = Keyword("LAZY")
protected val LEFT = Keyword("LEFT")
protected val LIKE = Keyword("LIKE")
protected val LIMIT = Keyword("LIMIT")
@@ -122,16 +96,18 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val SUBSTRING = Keyword("SUBSTRING")
protected val SUM = Keyword("SUM")
protected val TABLE = Keyword("TABLE")
+ protected val THEN = Keyword("THEN")
protected val TIMESTAMP = Keyword("TIMESTAMP")
protected val TRUE = Keyword("TRUE")
- protected val UNCACHE = Keyword("UNCACHE")
protected val UNION = Keyword("UNION")
protected val UPPER = Keyword("UPPER")
+ protected val WHEN = Keyword("WHEN")
protected val WHERE = Keyword("WHERE")
// Use reflection to find the reserved words defined in this class.
protected val reservedWords =
- this.getClass
+ this
+ .getClass
.getMethods
.filter(_.getReturnType == classOf[Keyword])
.map(_.invoke(this).asInstanceOf[Keyword].str)
@@ -145,86 +121,68 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
}
}
- protected lazy val query: Parser[LogicalPlan] = (
- select * (
- UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } |
- INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) } |
- EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} |
- UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) }
+ protected lazy val start: Parser[LogicalPlan] =
+ ( select *
+ ( UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) }
+ | INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) }
+ | EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)}
+ | UNION ~ DISTINCT.? ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) }
)
- | insert | cache | unCache
- )
+ | insert
+ )
protected lazy val select: Parser[LogicalPlan] =
- SELECT ~> opt(DISTINCT) ~ projections ~
- opt(from) ~ opt(filter) ~
- opt(grouping) ~
- opt(having) ~
- opt(orderBy) ~
- opt(limit) <~ opt(";") ^^ {
- case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l =>
- val base = r.getOrElse(NoRelation)
- val withFilter = f.map(f => Filter(f, base)).getOrElse(base)
- val withProjection =
- g.map {g =>
- Aggregate(g, assignAliases(p), withFilter)
- }.getOrElse(Project(assignAliases(p), withFilter))
- val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection)
- val withHaving = h.map(h => Filter(h, withDistinct)).getOrElse(withDistinct)
- val withOrder = o.map(o => Sort(o, withHaving)).getOrElse(withHaving)
- val withLimit = l.map { l => Limit(l, withOrder) }.getOrElse(withOrder)
- withLimit
- }
+ SELECT ~> DISTINCT.? ~
+ repsep(projection, ",") ~
+ (FROM ~> relations).? ~
+ (WHERE ~> expression).? ~
+ (GROUP ~ BY ~> rep1sep(expression, ",")).? ~
+ (HAVING ~> expression).? ~
+ (ORDER ~ BY ~> ordering).? ~
+ (LIMIT ~> expression).? ^^ {
+ case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l =>
+ val base = r.getOrElse(NoRelation)
+ val withFilter = f.map(f => Filter(f, base)).getOrElse(base)
+ val withProjection = g
+ .map(Aggregate(_, assignAliases(p), withFilter))
+ .getOrElse(Project(assignAliases(p), withFilter))
+ val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection)
+ val withHaving = h.map(Filter(_, withDistinct)).getOrElse(withDistinct)
+ val withOrder = o.map(Sort(_, withHaving)).getOrElse(withHaving)
+ val withLimit = l.map(Limit(_, withOrder)).getOrElse(withOrder)
+ withLimit
+ }
protected lazy val insert: Parser[LogicalPlan] =
- INSERT ~> opt(OVERWRITE) ~ inTo ~ select <~ opt(";") ^^ {
- case o ~ r ~ s =>
- val overwrite: Boolean = o.getOrElse("") == "OVERWRITE"
- InsertIntoTable(r, Map[String, Option[String]](), s, overwrite)
- }
-
- protected lazy val cache: Parser[LogicalPlan] =
- CACHE ~> opt(LAZY) ~ (TABLE ~> ident) ~ opt(AS ~> select) <~ opt(";") ^^ {
- case isLazy ~ tableName ~ plan =>
- CacheTableCommand(tableName, plan, isLazy.isDefined)
- }
-
- protected lazy val unCache: Parser[LogicalPlan] =
- UNCACHE ~ TABLE ~> ident <~ opt(";") ^^ {
- case tableName => UncacheTableCommand(tableName)
+ INSERT ~> OVERWRITE.? ~ (INTO ~> relation) ~ select ^^ {
+ case o ~ r ~ s => InsertIntoTable(r, Map.empty[String, Option[String]], s, o.isDefined)
}
- protected lazy val projections: Parser[Seq[Expression]] = repsep(projection, ",")
-
protected lazy val projection: Parser[Expression] =
- expression ~ (opt(AS) ~> opt(ident)) ^^ {
- case e ~ None => e
- case e ~ Some(a) => Alias(e, a)()
+ expression ~ (AS.? ~> ident.?) ^^ {
+ case e ~ a => a.fold(e)(Alias(e, _)())
}
- protected lazy val from: Parser[LogicalPlan] = FROM ~> relations
-
- protected lazy val inTo: Parser[LogicalPlan] = INTO ~> relation
-
// Based very loosely on the MySQL Grammar.
// http://dev.mysql.com/doc/refman/5.0/en/join.html
protected lazy val relations: Parser[LogicalPlan] =
- relation ~ "," ~ relation ^^ { case r1 ~ _ ~ r2 => Join(r1, r2, Inner, None) } |
- relation
+ ( relation ~ ("," ~> relation) ^^ { case r1 ~ r2 => Join(r1, r2, Inner, None) }
+ | relation
+ )
protected lazy val relation: Parser[LogicalPlan] =
- joinedRelation |
- relationFactor
+ joinedRelation | relationFactor
protected lazy val relationFactor: Parser[LogicalPlan] =
- ident ~ (opt(AS) ~> opt(ident)) ^^ {
- case tableName ~ alias => UnresolvedRelation(None, tableName, alias)
- } |
- "(" ~> query ~ ")" ~ opt(AS) ~ ident ^^ { case s ~ _ ~ _ ~ a => Subquery(a, s) }
+ ( ident ~ (opt(AS) ~> opt(ident)) ^^ {
+ case tableName ~ alias => UnresolvedRelation(None, tableName, alias)
+ }
+ | ("(" ~> start <~ ")") ~ (AS.? ~> ident) ^^ { case s ~ a => Subquery(a, s) }
+ )
protected lazy val joinedRelation: Parser[LogicalPlan] =
- relationFactor ~ opt(joinType) ~ JOIN ~ relationFactor ~ opt(joinConditions) ^^ {
- case r1 ~ jt ~ _ ~ r2 ~ cond =>
+ relationFactor ~ joinType.? ~ (JOIN ~> relationFactor) ~ joinConditions.? ^^ {
+ case r1 ~ jt ~ r2 ~ cond =>
Join(r1, r2, joinType = jt.getOrElse(Inner), cond)
}
@@ -232,151 +190,145 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
ON ~> expression
protected lazy val joinType: Parser[JoinType] =
- INNER ^^^ Inner |
- LEFT ~ SEMI ^^^ LeftSemi |
- LEFT ~ opt(OUTER) ^^^ LeftOuter |
- RIGHT ~ opt(OUTER) ^^^ RightOuter |
- FULL ~ opt(OUTER) ^^^ FullOuter
-
- protected lazy val filter: Parser[Expression] = WHERE ~ expression ^^ { case _ ~ e => e }
-
- protected lazy val orderBy: Parser[Seq[SortOrder]] =
- ORDER ~> BY ~> ordering
+ ( INNER ^^^ Inner
+ | LEFT ~ SEMI ^^^ LeftSemi
+ | LEFT ~ OUTER.? ^^^ LeftOuter
+ | RIGHT ~ OUTER.? ^^^ RightOuter
+ | FULL ~ OUTER.? ^^^ FullOuter
+ )
protected lazy val ordering: Parser[Seq[SortOrder]] =
- rep1sep(singleOrder, ",") |
- rep1sep(expression, ",") ~ opt(direction) ^^ {
- case exps ~ None => exps.map(SortOrder(_, Ascending))
- case exps ~ Some(d) => exps.map(SortOrder(_, d))
- }
+ ( rep1sep(singleOrder, ",")
+ | rep1sep(expression, ",") ~ direction.? ^^ {
+ case exps ~ d => exps.map(SortOrder(_, d.getOrElse(Ascending)))
+ }
+ )
protected lazy val singleOrder: Parser[SortOrder] =
- expression ~ direction ^^ { case e ~ o => SortOrder(e,o) }
+ expression ~ direction ^^ { case e ~ o => SortOrder(e, o) }
protected lazy val direction: Parser[SortDirection] =
- ASC ^^^ Ascending |
- DESC ^^^ Descending
-
- protected lazy val grouping: Parser[Seq[Expression]] =
- GROUP ~> BY ~> rep1sep(expression, ",")
-
- protected lazy val having: Parser[Expression] =
- HAVING ~> expression
-
- protected lazy val limit: Parser[Expression] =
- LIMIT ~> expression
+ ( ASC ^^^ Ascending
+ | DESC ^^^ Descending
+ )
- protected lazy val expression: Parser[Expression] = orExpression
+ protected lazy val expression: Parser[Expression] =
+ orExpression
protected lazy val orExpression: Parser[Expression] =
- andExpression * (OR ^^^ { (e1: Expression, e2: Expression) => Or(e1,e2) })
+ andExpression * (OR ^^^ { (e1: Expression, e2: Expression) => Or(e1, e2) })
protected lazy val andExpression: Parser[Expression] =
- comparisonExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1,e2) })
+ comparisonExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1, e2) })
protected lazy val comparisonExpression: Parser[Expression] =
- termExpression ~ "=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => EqualTo(e1, e2) } |
- termExpression ~ "<" ~ termExpression ^^ { case e1 ~ _ ~ e2 => LessThan(e1, e2) } |
- termExpression ~ "<=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => LessThanOrEqual(e1, e2) } |
- termExpression ~ ">" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThan(e1, e2) } |
- termExpression ~ ">=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThanOrEqual(e1, e2) } |
- termExpression ~ "!=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } |
- termExpression ~ "<>" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(EqualTo(e1, e2)) } |
- termExpression ~ BETWEEN ~ termExpression ~ AND ~ termExpression ^^ {
- case e ~ _ ~ el ~ _ ~ eu => And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu))
- } |
- termExpression ~ RLIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } |
- termExpression ~ REGEXP ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } |
- termExpression ~ LIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => Like(e1, e2) } |
- termExpression ~ IN ~ "(" ~ rep1sep(termExpression, ",") <~ ")" ^^ {
- case e1 ~ _ ~ _ ~ e2 => In(e1, e2)
- } |
- termExpression ~ NOT ~ IN ~ "(" ~ rep1sep(termExpression, ",") <~ ")" ^^ {
- case e1 ~ _ ~ _ ~ _ ~ e2 => Not(In(e1, e2))
- } |
- termExpression <~ IS ~ NULL ^^ { case e => IsNull(e) } |
- termExpression <~ IS ~ NOT ~ NULL ^^ { case e => IsNotNull(e) } |
- NOT ~> termExpression ^^ {e => Not(e)} |
- termExpression
+ ( termExpression ~ ("=" ~> termExpression) ^^ { case e1 ~ e2 => EqualTo(e1, e2) }
+ | termExpression ~ ("<" ~> termExpression) ^^ { case e1 ~ e2 => LessThan(e1, e2) }
+ | termExpression ~ ("<=" ~> termExpression) ^^ { case e1 ~ e2 => LessThanOrEqual(e1, e2) }
+ | termExpression ~ (">" ~> termExpression) ^^ { case e1 ~ e2 => GreaterThan(e1, e2) }
+ | termExpression ~ (">=" ~> termExpression) ^^ { case e1 ~ e2 => GreaterThanOrEqual(e1, e2) }
+ | termExpression ~ ("!=" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) }
+ | termExpression ~ ("<>" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) }
+ | termExpression ~ (BETWEEN ~> termExpression) ~ (AND ~> termExpression) ^^ {
+ case e ~ el ~ eu => And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu))
+ }
+ | termExpression ~ (RLIKE ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) }
+ | termExpression ~ (REGEXP ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) }
+ | termExpression ~ (LIKE ~> termExpression) ^^ { case e1 ~ e2 => Like(e1, e2) }
+ | termExpression ~ (IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ {
+ case e1 ~ e2 => In(e1, e2)
+ }
+ | termExpression ~ (NOT ~ IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ {
+ case e1 ~ e2 => Not(In(e1, e2))
+ }
+ | termExpression <~ IS ~ NULL ^^ { case e => IsNull(e) }
+ | termExpression <~ IS ~ NOT ~ NULL ^^ { case e => IsNotNull(e) }
+ | NOT ~> termExpression ^^ {e => Not(e)}
+ | termExpression
+ )
protected lazy val termExpression: Parser[Expression] =
- productExpression * (
- "+" ^^^ { (e1: Expression, e2: Expression) => Add(e1,e2) } |
- "-" ^^^ { (e1: Expression, e2: Expression) => Subtract(e1,e2) } )
+ productExpression *
+ ( "+" ^^^ { (e1: Expression, e2: Expression) => Add(e1, e2) }
+ | "-" ^^^ { (e1: Expression, e2: Expression) => Subtract(e1, e2) }
+ )
protected lazy val productExpression: Parser[Expression] =
- baseExpression * (
- "*" ^^^ { (e1: Expression, e2: Expression) => Multiply(e1,e2) } |
- "/" ^^^ { (e1: Expression, e2: Expression) => Divide(e1,e2) } |
- "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1,e2) }
- )
+ baseExpression *
+ ( "*" ^^^ { (e1: Expression, e2: Expression) => Multiply(e1, e2) }
+ | "/" ^^^ { (e1: Expression, e2: Expression) => Divide(e1, e2) }
+ | "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1, e2) }
+ )
protected lazy val function: Parser[Expression] =
- SUM ~> "(" ~> expression <~ ")" ^^ { case exp => Sum(exp) } |
- SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) } |
- COUNT ~> "(" ~ "*" <~ ")" ^^ { case _ => Count(Literal(1)) } |
- COUNT ~> "(" ~ expression <~ ")" ^^ { case dist ~ exp => Count(exp) } |
- COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } |
- APPROXIMATE ~> COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ {
- case exp => ApproxCountDistinct(exp)
- } |
- APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ {
- case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble)
- } |
- FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } |
- LAST ~> "(" ~> expression <~ ")" ^^ { case exp => Last(exp) } |
- AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } |
- MIN ~> "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } |
- MAX ~> "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } |
- UPPER ~> "(" ~> expression <~ ")" ^^ { case exp => Upper(exp) } |
- LOWER ~> "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) } |
- IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ {
- case c ~ "," ~ t ~ "," ~ f => If(c,t,f)
- } |
- (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression <~ ")" ^^ {
- case s ~ "," ~ p => Substring(s,p,Literal(Integer.MAX_VALUE))
- } |
- (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ {
- case s ~ "," ~ p ~ "," ~ l => Substring(s,p,l)
- } |
- SQRT ~> "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } |
- ABS ~> "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) } |
- ident ~ "(" ~ repsep(expression, ",") <~ ")" ^^ {
- case udfName ~ _ ~ exprs => UnresolvedFunction(udfName, exprs)
- }
+ ( SUM ~> "(" ~> expression <~ ")" ^^ { case exp => Sum(exp) }
+ | SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) }
+ | COUNT ~ "(" ~> "*" <~ ")" ^^ { case _ => Count(Literal(1)) }
+ | COUNT ~ "(" ~> expression <~ ")" ^^ { case exp => Count(exp) }
+ | COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) }
+ | APPROXIMATE ~ COUNT ~ "(" ~ DISTINCT ~> expression <~ ")" ^^
+ { case exp => ApproxCountDistinct(exp) }
+ | APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^
+ { case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble) }
+ | FIRST ~ "(" ~> expression <~ ")" ^^ { case exp => First(exp) }
+ | LAST ~ "(" ~> expression <~ ")" ^^ { case exp => Last(exp) }
+ | AVG ~ "(" ~> expression <~ ")" ^^ { case exp => Average(exp) }
+ | MIN ~ "(" ~> expression <~ ")" ^^ { case exp => Min(exp) }
+ | MAX ~ "(" ~> expression <~ ")" ^^ { case exp => Max(exp) }
+ | UPPER ~ "(" ~> expression <~ ")" ^^ { case exp => Upper(exp) }
+ | LOWER ~ "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) }
+ | IF ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^
+ { case c ~ t ~ f => If(c, t, f) }
+ | CASE ~> expression.? ~ (WHEN ~> expression ~ (THEN ~> expression)).* ~
+ (ELSE ~> expression).? <~ END ^^ {
+ case casePart ~ altPart ~ elsePart =>
+ val altExprs = altPart.flatMap { case whenExpr ~ thenExpr =>
+ Seq(casePart.fold(whenExpr)(EqualTo(_, whenExpr)), thenExpr)
+ }
+ CaseWhen(altExprs ++ elsePart.toList)
+ }
+ | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) <~ ")" ^^
+ { case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) }
+ | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^
+ { case s ~ p ~ l => Substring(s, p, l) }
+ | SQRT ~ "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) }
+ | ABS ~ "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) }
+ | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^
+ { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) }
+ )
protected lazy val cast: Parser[Expression] =
- CAST ~> "(" ~> expression ~ AS ~ dataType <~ ")" ^^ { case exp ~ _ ~ t => Cast(exp, t) }
+ CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ { case exp ~ t => Cast(exp, t) }
protected lazy val literal: Parser[Literal] =
- numericLit ^^ {
- case i if i.toLong > Int.MaxValue => Literal(i.toLong)
- case i => Literal(i.toInt)
- } |
- NULL ^^^ Literal(null, NullType) |
- floatLit ^^ {case f => Literal(f.toDouble) } |
- stringLit ^^ {case s => Literal(s, StringType) }
+ ( numericLit ^^ {
+ case i if i.toLong > Int.MaxValue => Literal(i.toLong)
+ case i => Literal(i.toInt)
+ }
+ | NULL ^^^ Literal(null, NullType)
+ | floatLit ^^ {case f => Literal(f.toDouble) }
+ | stringLit ^^ {case s => Literal(s, StringType) }
+ )
protected lazy val floatLit: Parser[String] =
elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars)
protected lazy val baseExpression: PackratParser[Expression] =
- expression ~ "[" ~ expression <~ "]" ^^ {
- case base ~ _ ~ ordinal => GetItem(base, ordinal)
- } |
- (expression <~ ".") ~ ident ^^ {
- case base ~ fieldName => GetField(base, fieldName)
- } |
- TRUE ^^^ Literal(true, BooleanType) |
- FALSE ^^^ Literal(false, BooleanType) |
- cast |
- "(" ~> expression <~ ")" |
- function |
- "-" ~> literal ^^ UnaryMinus |
- dotExpressionHeader |
- ident ^^ UnresolvedAttribute |
- "*" ^^^ Star(None) |
- literal
+ ( expression ~ ("[" ~> expression <~ "]") ^^
+ { case base ~ ordinal => GetItem(base, ordinal) }
+ | (expression <~ ".") ~ ident ^^
+ { case base ~ fieldName => GetField(base, fieldName) }
+ | TRUE ^^^ Literal(true, BooleanType)
+ | FALSE ^^^ Literal(false, BooleanType)
+ | cast
+ | "(" ~> expression <~ ")"
+ | function
+ | "-" ~> literal ^^ UnaryMinus
+ | dotExpressionHeader
+ | ident ^^ UnresolvedAttribute
+ | "*" ^^^ Star(None)
+ | literal
+ )
protected lazy val dotExpressionHeader: Parser[Expression] =
(ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ {
@@ -386,55 +338,3 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected lazy val dataType: Parser[DataType] =
STRING ^^^ StringType | TIMESTAMP ^^^ TimestampType
}
-
-class SqlLexical(val keywords: Seq[String]) extends StdLexical {
- case class FloatLit(chars: String) extends Token {
- override def toString = chars
- }
-
- reserved ++= keywords.flatMap(w => allCaseVersions(w))
-
- delimiters += (
- "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
- ",", ";", "%", "{", "}", ":", "[", "]", "."
- )
-
- override lazy val token: Parser[Token] = (
- identChar ~ rep( identChar | digit ) ^^
- { case first ~ rest => processIdent(first :: rest mkString "") }
- | rep1(digit) ~ opt('.' ~> rep(digit)) ^^ {
- case i ~ None => NumericLit(i mkString "")
- case i ~ Some(d) => FloatLit(i.mkString("") + "." + d.mkString(""))
- }
- | '\'' ~ rep( chrExcept('\'', '\n', EofCh) ) ~ '\'' ^^
- { case '\'' ~ chars ~ '\'' => StringLit(chars mkString "") }
- | '\"' ~ rep( chrExcept('\"', '\n', EofCh) ) ~ '\"' ^^
- { case '\"' ~ chars ~ '\"' => StringLit(chars mkString "") }
- | EofCh ^^^ EOF
- | '\'' ~> failure("unclosed string literal")
- | '\"' ~> failure("unclosed string literal")
- | delim
- | failure("illegal character")
- )
-
- override def identChar = letter | elem('_')
-
- override def whitespace: Parser[Any] = rep(
- whitespaceChar
- | '/' ~ '*' ~ comment
- | '/' ~ '/' ~ rep( chrExcept(EofCh, '\n') )
- | '#' ~ rep( chrExcept(EofCh, '\n') )
- | '-' ~ '-' ~ rep( chrExcept(EofCh, '\n') )
- | '/' ~ '*' ~ failure("unclosed comment")
- )
-
- /** Generate all variations of upper and lower case of a given string */
- def allCaseVersions(s: String, prefix: String = ""): Stream[String] = {
- if (s == "") {
- Stream(prefix)
- } else {
- allCaseVersions(s.tail, prefix + s.head.toLower) ++
- allCaseVersions(s.tail, prefix + s.head.toUpper)
- }
- }
-}
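
The CASE/WHEN/THEN/ELSE/END keywords added above give SqlParser both the searched and the simple form of CASE. A hedged usage sketch against Spark SQL 1.2-era APIs; the table, column names, and data are illustrative only:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

case class Person(name: String, age: Int)

object CaseWhenSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("case-when").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.createSchemaRDD

    sc.parallelize(Seq(Person("a", 12), Person("b", 35))).registerTempTable("people")

    // Searched form: CASE WHEN <condition> THEN <expr> ... ELSE <expr> END
    val searched = sqlContext.sql(
      "SELECT name, CASE WHEN age < 18 THEN 'minor' ELSE 'adult' END FROM people")
    searched.collect().foreach(println)

    // Simple form: CASE <expr> WHEN <value> THEN <expr> ... ELSE <expr> END
    val simple = sqlContext.sql(
      "SELECT name, CASE age WHEN 12 THEN 'twelve' ELSE 'other' END FROM people")
    simple.collect().foreach(println)

    sc.stop()
  }
}
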
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index ef1d12531f109..204904ecf04db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -137,6 +137,9 @@ class JoinedRow extends Row {
def getString(i: Int): String =
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getAs[T](i: Int): T =
+ if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+
def copy() = {
val totalSize = row1.size + row2.size
val copiedValues = new Array[Any](totalSize)
@@ -226,6 +229,9 @@ class JoinedRow2 extends Row {
def getString(i: Int): String =
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getAs[T](i: Int): T =
+ if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+
def copy() = {
val totalSize = row1.size + row2.size
val copiedValues = new Array[Any](totalSize)
@@ -309,6 +315,9 @@ class JoinedRow3 extends Row {
def getString(i: Int): String =
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getAs[T](i: Int): T =
+ if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+
def copy() = {
val totalSize = row1.size + row2.size
val copiedValues = new Array[Any](totalSize)
@@ -392,6 +401,9 @@ class JoinedRow4 extends Row {
def getString(i: Int): String =
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getAs[T](i: Int): T =
+ if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+
def copy() = {
val totalSize = row1.size + row2.size
val copiedValues = new Array[Any](totalSize)
@@ -475,6 +487,9 @@ class JoinedRow5 extends Row {
def getString(i: Int): String =
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
+ override def getAs[T](i: Int): T =
+ if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
+
def copy() = {
val totalSize = row1.size + row2.size
val copiedValues = new Array[Any](totalSize)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
index d68a4fabeac77..d00ec39774c35 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
@@ -64,6 +64,7 @@ trait Row extends Seq[Any] with Serializable {
def getShort(i: Int): Short
def getByte(i: Int): Byte
def getString(i: Int): String
+ def getAs[T](i: Int): T = apply(i).asInstanceOf[T]
override def toString() =
s"[${this.mkString(",")}]"
@@ -118,6 +119,7 @@ object EmptyRow extends Row {
def getShort(i: Int): Short = throw new UnsupportedOperationException
def getByte(i: Int): Byte = throw new UnsupportedOperationException
def getString(i: Int): String = throw new UnsupportedOperationException
+ override def getAs[T](i: Int): T = throw new UnsupportedOperationException
def copy() = this
}
@@ -217,19 +219,19 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow {
/** No-arg constructor for serialization. */
def this() = this(0)
- override def setBoolean(ordinal: Int,value: Boolean): Unit = { values(ordinal) = value }
- override def setByte(ordinal: Int,value: Byte): Unit = { values(ordinal) = value }
- override def setDouble(ordinal: Int,value: Double): Unit = { values(ordinal) = value }
- override def setFloat(ordinal: Int,value: Float): Unit = { values(ordinal) = value }
- override def setInt(ordinal: Int,value: Int): Unit = { values(ordinal) = value }
- override def setLong(ordinal: Int,value: Long): Unit = { values(ordinal) = value }
- override def setString(ordinal: Int,value: String): Unit = { values(ordinal) = value }
+ override def setBoolean(ordinal: Int, value: Boolean): Unit = { values(ordinal) = value }
+ override def setByte(ordinal: Int, value: Byte): Unit = { values(ordinal) = value }
+ override def setDouble(ordinal: Int, value: Double): Unit = { values(ordinal) = value }
+ override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value }
+ override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value }
+ override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value }
+ override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = value }
override def setNullAt(i: Int): Unit = { values(i) = null }
- override def setShort(ordinal: Int,value: Short): Unit = { values(ordinal) = value }
+ override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value }
- override def update(ordinal: Int,value: Any): Unit = { values(ordinal) = value }
+ override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value }
override def copy() = new GenericRow(values.clone())
}
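
getAs[T] gives callers typed access to a column without chaining type-specific getters, and the JoinedRow overrides earlier in this patch simply forward the call to whichever side of the join owns the ordinal. A small sketch against the catalyst-internal GenericRow, purely for illustration:

import org.apache.spark.sql.catalyst.expressions.GenericRow

object GetAsSketch {
  def main(args: Array[String]): Unit = {
    val row = new GenericRow(Array[Any]("spark", 42, null))
    val name: String = row.getAs[String](0)  // same as row(0).asInstanceOf[String]
    val count: Int   = row.getAs[Int](1)
    println(s"$name / $count / isNullAt(2) = ${row.isNullAt(2)}")
  }
}
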
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
similarity index 97%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index 9cbab3d5d0d0d..570379c533e1f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -233,9 +233,9 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
override def iterator: Iterator[Any] = values.map(_.boxed).iterator
- def setString(ordinal: Int, value: String) = update(ordinal, value)
+ override def setString(ordinal: Int, value: String) = update(ordinal, value)
- def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String]
+ override def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String]
override def setInt(ordinal: Int, value: Int): Unit = {
val currentValue = values(ordinal).asInstanceOf[MutableInt]
@@ -306,4 +306,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
override def getByte(i: Int): Byte = {
values(i).asInstanceOf[MutableByte].value
}
+
+ override def getAs[T](i: Int): T = {
+ values(i).boxed.asInstanceOf[T]
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 329af332d0fa1..1e22b2d03c672 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -17,11 +17,11 @@
package org.apache.spark.sql.catalyst.expressions
+import scala.collection.immutable.HashSet
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.types.BooleanType
-
object InterpretedPredicate {
def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) =
apply(BindReferences.bindReference(expression, inputSchema))
@@ -95,6 +95,23 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
}
}
+/**
+ * Optimized version of In clause, when all filter values of In clause are
+ * static.
+ */
+case class InSet(value: Expression, hset: HashSet[Any], child: Seq[Expression])
+ extends Predicate {
+
+ def children = child
+
+ def nullable = true // TODO: Figure out correct nullability semantics of IN.
+ override def toString = s"$value INSET ${hset.mkString("(", ",", ")")}"
+
+ override def eval(input: Row): Any = {
+ hset.contains(value.eval(input))
+ }
+}
+
case class And(left: Expression, right: Expression) extends BinaryPredicate {
def symbol = "&&"
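
InSet trades In's per-row scan over the list expressions for a single hash lookup against values that were evaluated once up front; the OptimizeIn rule in the next file performs that rewrite when every list element is a literal. A sketch that evaluates both forms directly through catalyst's internal expression API (illustration only):

import scala.collection.immutable.HashSet
import org.apache.spark.sql.catalyst.expressions.{In, InSet, Literal}

object InSetSketch {
  def main(args: Array[String]): Unit = {
    val one  = Literal(1)
    val list = Seq(Literal(1), Literal(2))

    // In re-evaluates each list expression for every input row.
    println(In(one, list).eval(null))                                              // true

    // InSet carries a pre-built HashSet of the literal values.
    println(InSet(one, HashSet[Any](1, 2), one +: list).eval(null))                // true
    println(InSet(Literal(3), HashSet[Any](1, 2), Literal(3) +: list).eval(null))  // false
  }
}
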
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 636d0b95583e4..3693b41404fd6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.optimizer
+import scala.collection.immutable.HashSet
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.FullOuter
@@ -38,7 +39,8 @@ object Optimizer extends RuleExecutor[LogicalPlan] {
BooleanSimplification,
SimplifyFilters,
SimplifyCasts,
- SimplifyCaseConversionExpressions) ::
+ SimplifyCaseConversionExpressions,
+ OptimizeIn) ::
Batch("Filter Pushdown", FixedPoint(100),
UnionPushdown,
CombineFilters,
@@ -273,6 +275,20 @@ object ConstantFolding extends Rule[LogicalPlan] {
}
}
+/**
+ * Replaces [[In (value, seq[Literal])]] with the optimized version
+ * [[InSet (value, HashSet[Literal])]], which is much faster.
+ */
+object OptimizeIn extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case q: LogicalPlan => q transformExpressionsDown {
+ case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) =>
+ val hSet = list.map(e => e.eval(null))
+ InSet(v, HashSet() ++ hSet, v +: list)
+ }
+ }
+}
+
/**
* Simplifies boolean expressions where the answer can be determined without evaluating both sides.
* Note that this rule can eliminate expressions that might otherwise have been evaluated and thus
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
index 9a3848cfc6b62..b8ba2ee428a20 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
@@ -39,9 +39,9 @@ case class NativeCommand(cmd: String) extends Command {
}
/**
- * Commands of the form "SET (key) (= value)".
+ * Commands of the form "SET [key [= value]]".
*/
-case class SetCommand(key: Option[String], value: Option[String]) extends Command {
+case class SetCommand(kv: Option[(String, Option[String])]) extends Command {
override def output = Seq(
AttributeReference("", StringType, nullable = false)())
}
@@ -81,3 +81,14 @@ case class DescribeCommand(
AttributeReference("data_type", StringType, nullable = false)(),
AttributeReference("comment", StringType, nullable = false)())
}
+
+/**
+ * Returned for the "! shellCommand" command
+ */
+case class ShellCommand(cmd: String) extends Command
+
+
+/**
+ * Returned for the "SOURCE file" command
+ */
+case class SourceCommand(filePath: String) extends Command
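
With the kv option, the three accepted forms of SET map onto a single node shape instead of the old pair of independent options. A sketch of that mapping, constructing the logical nodes directly (the configuration key is just an example):

import org.apache.spark.sql.catalyst.plans.logical.SetCommand

object SetCommandShapes {
  def main(args: Array[String]): Unit = {
    val listAll = SetCommand(None)                                               // SET
    val getOne  = SetCommand(Some("spark.sql.shuffle.partitions" -> None))       // SET key
    val setOne  = SetCommand(Some("spark.sql.shuffle.partitions" -> Some("10"))) // SET key=value
    Seq(listAll, getOne, setOne).foreach(println)
  }
}
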
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 63931af4bac3d..692ed78a7292c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -19,12 +19,15 @@ package org.apache.spark.sql.catalyst.expressions
import java.sql.Timestamp
+import scala.collection.immutable.HashSet
+
import org.scalatest.FunSuite
import org.scalatest.Matchers._
import org.scalautils.TripleEqualsSupport.Spread
import org.apache.spark.sql.catalyst.types._
+
/* Implicit conversions */
import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -145,6 +148,24 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true)
}
+ test("INSET") {
+ val hS = HashSet[Any]() + 1 + 2
+ val nS = HashSet[Any]() + 1 + 2 + null
+ val one = Literal(1)
+ val two = Literal(2)
+ val three = Literal(3)
+ val nl = Literal(null)
+ val s = Seq(one, two)
+ val nullS = Seq(one, two, null)
+ checkEvaluation(InSet(one, hS, one +: s), true)
+ checkEvaluation(InSet(two, hS, two +: s), true)
+ checkEvaluation(InSet(two, nS, two +: nullS), true)
+ checkEvaluation(InSet(nl, nS, nl +: nullS), true)
+ checkEvaluation(InSet(three, hS, three +: s), false)
+ checkEvaluation(InSet(three, nS, three +: nullS), false)
+ checkEvaluation(InSet(one, hS, one +: s) && InSet(two, hS, two +: s), true)
+ }
+
test("MaxOf") {
checkEvaluation(MaxOf(1, 2), 2)
checkEvaluation(MaxOf(2, 1), 2)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
new file mode 100644
index 0000000000000..97a78ec971c39
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.immutable.HashSet
+import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.types._
+
+// For implicit conversions
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+
+class OptimizeInSuite extends PlanTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("AnalysisNodes", Once,
+ EliminateAnalysisOperators) ::
+ Batch("ConstantFolding", Once,
+ ConstantFolding,
+ BooleanSimplification,
+ OptimizeIn) :: Nil
+ }
+
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+ test("OptimizedIn test: In clause optimized to InSet") {
+ val originalQuery =
+ testRelation
+ .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2))))
+ .analyze
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .where(InSet(UnresolvedAttribute("a"), HashSet[Any]()+1+2,
+ UnresolvedAttribute("a") +: Seq(Literal(1),Literal(2))))
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("OptimizedIn test: In clause not optimized in case filter has attributes") {
+ val originalQuery =
+ testRelation
+ .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b"))))
+ .analyze
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b"))))
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
index 3bf7382ac67a6..5ab2b5316ab10 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
@@ -22,7 +22,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.columnar.InMemoryRelation
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
+import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK
/** Holds a cached logical plan and its data */
private case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation)
@@ -74,10 +74,14 @@ private[sql] trait CacheManager {
cachedData.clear()
}
- /** Caches the data produced by the logical representation of the given schema rdd. */
+ /**
+   * Caches the data produced by the logical representation of the given SchemaRDD. Unlike
+   * `RDD.cache()`, the default storage level is set to `MEMORY_AND_DISK` because recomputing
+   * the in-memory columnar representation of the underlying table is expensive.
+ */
private[sql] def cacheQuery(
query: SchemaRDD,
- storageLevel: StorageLevel = MEMORY_ONLY): Unit = writeLock {
+ storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
val planToCache = query.queryExecution.optimizedPlan
if (lookupCachedData(planToCache).nonEmpty) {
logWarning("Asked to cache already cached data.")
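
Because the columnar batches are expensive to rebuild, cacheQuery now defaults to MEMORY_AND_DISK, so an evicted batch spills to disk instead of forcing recomputation. A usage sketch with the public cacheTable API (Spark SQL 1.2-era; the table name and data are illustrative):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object CacheLevelSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("cache-sketch").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.createSchemaRDD

    sc.parallelize(1 to 100).map(i => (i, i.toString)).registerTempTable("t")

    sqlContext.cacheTable("t")   // in-memory columnar cache, MEMORY_AND_DISK by default now
    sqlContext.sql("SELECT COUNT(*) FROM t").collect().foreach(println)
    sqlContext.uncacheTable("t")

    sc.stop()
  }
}
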
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index f6f4cf3b80d41..07e6e2eccddf4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -35,6 +35,7 @@ private[spark] object SQLConf {
val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString"
val PARQUET_CACHE_METADATA = "spark.sql.parquet.cacheMetadata"
val PARQUET_COMPRESSION = "spark.sql.parquet.compression.codec"
+ val COLUMN_NAME_OF_CORRUPT_RECORD = "spark.sql.columnNameOfCorruptRecord"
// This is only used for the thriftserver
val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool"
@@ -131,6 +132,9 @@ private[sql] trait SQLConf {
private[spark] def inMemoryPartitionPruning: Boolean =
getConf(IN_MEMORY_PARTITION_PRUNING, "false").toBoolean
+ private[spark] def columnNameOfCorruptRecord: String =
+ getConf(COLUMN_NAME_OF_CORRUPT_RECORD, "_corrupt_record")
+
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
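
spark.sql.columnNameOfCorruptRecord controls the column JsonRDD uses to carry records it could not parse, rather than dropping them or failing the whole job. A hedged sketch of overriding the default _corrupt_record name and querying the bad rows (data and names are illustrative):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object CorruptRecordSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("json-sketch").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)

    sqlContext.setConf("spark.sql.columnNameOfCorruptRecord", "_bad_record")

    val json = sc.parallelize(Seq(
      """{"name": "a", "age": 1}""",
      """{"name": "b", "age": """))   // malformed on purpose

    val records = sqlContext.jsonRDD(json)
    records.printSchema()             // schema picks up the _bad_record column
    records.registerTempTable("records")
    val bad = sqlContext.sql("SELECT _bad_record FROM records WHERE _bad_record IS NOT NULL")
    bad.collect().foreach(println)

    sc.stop()
  }
}
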
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 35561cac3e5e1..23e7b2d270777 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -66,12 +66,17 @@ class SQLContext(@transient val sparkContext: SparkContext)
@transient
protected[sql] lazy val analyzer: Analyzer =
new Analyzer(catalog, functionRegistry, caseSensitive = true)
+
@transient
protected[sql] val optimizer = Optimizer
+
@transient
- protected[sql] val parser = new catalyst.SqlParser
+ protected[sql] val sqlParser = {
+ val fallback = new catalyst.SqlParser
+ new catalyst.SparkSQLParser(fallback(_))
+ }
- protected[sql] def parseSql(sql: String): LogicalPlan = parser(sql)
+ protected[sql] def parseSql(sql: String): LogicalPlan = sqlParser(sql)
protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql))
protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution =
new this.QueryExecution { val logical = plan }
@@ -195,9 +200,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
@Experimental
def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = {
+ val columnNameOfCorruptJsonRecord = columnNameOfCorruptRecord
val appliedSchema =
- Option(schema).getOrElse(JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, 1.0)))
- val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema)
+ Option(schema).getOrElse(
+ JsonRDD.nullTypeToStringType(
+ JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord)))
+ val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
applySchema(rowRDD, appliedSchema)
}
@@ -206,8 +214,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
@Experimental
def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = {
- val appliedSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, samplingRatio))
- val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema)
+ val columnNameOfCorruptJsonRecord = columnNameOfCorruptRecord
+ val appliedSchema =
+ JsonRDD.nullTypeToStringType(
+ JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord))
+ val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
applySchema(rowRDD, appliedSchema)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 594bf8ffc20e1..948122d42f0e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -360,7 +360,7 @@ class SchemaRDD(
join: Boolean = false,
outer: Boolean = false,
alias: Option[String] = None) =
- new SchemaRDD(sqlContext, Generate(generator, join, outer, None, logicalPlan))
+ new SchemaRDD(sqlContext, Generate(generator, join, outer, alias, logicalPlan))
/**
* Returns this RDD as a SchemaRDD. Intended primarily to force the invocation of the implicit
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
index c006c4330ff66..f8171c3be3207 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -148,8 +148,12 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
* It goes through the entire dataset once to determine the schema.
*/
def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = {
- val appliedScalaSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))
- val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema)
+ val columnNameOfCorruptJsonRecord = sqlContext.columnNameOfCorruptRecord
+ val appliedScalaSchema =
+ JsonRDD.nullTypeToStringType(
+ JsonRDD.inferSchema(json.rdd, 1.0, columnNameOfCorruptJsonRecord))
+ val scalaRowRDD =
+ JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema, columnNameOfCorruptJsonRecord)
val logicalPlan =
LogicalRDD(appliedScalaSchema.toAttributes, scalaRowRDD)(sqlContext)
new JavaSchemaRDD(sqlContext, logicalPlan)
@@ -162,10 +166,14 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
*/
@Experimental
def jsonRDD(json: JavaRDD[String], schema: StructType): JavaSchemaRDD = {
+ val columnNameOfCorruptJsonRecord = sqlContext.columnNameOfCorruptRecord
val appliedScalaSchema =
Option(asScalaDataType(schema)).getOrElse(
- JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[SStructType]
- val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema)
+ JsonRDD.nullTypeToStringType(
+ JsonRDD.inferSchema(
+ json.rdd, 1.0, columnNameOfCorruptJsonRecord))).asInstanceOf[SStructType]
+ val scalaRowRDD = JsonRDD.jsonStringToRow(
+ json.rdd, appliedScalaSchema, columnNameOfCorruptJsonRecord)
val logicalPlan =
LogicalRDD(appliedScalaSchema.toAttributes, scalaRowRDD)(sqlContext)
new JavaSchemaRDD(sqlContext, logicalPlan)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index 4f79173a26f88..22ab0e2613f21 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -38,7 +38,7 @@ private[sql] object InMemoryRelation {
new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child)()
}
-private[sql] case class CachedBatch(buffers: Array[ByteBuffer], stats: Row)
+private[sql] case class CachedBatch(buffers: Array[Array[Byte]], stats: Row)
private[sql] case class InMemoryRelation(
output: Seq[Attribute],
@@ -91,7 +91,7 @@ private[sql] case class InMemoryRelation(
val stats = Row.fromSeq(
columnBuilders.map(_.columnStats.collectedStatistics).foldLeft(Seq.empty[Any])(_ ++ _))
- CachedBatch(columnBuilders.map(_.build()), stats)
+ CachedBatch(columnBuilders.map(_.build().array()), stats)
}
def hasNext = rowIterator.hasNext
@@ -238,8 +238,9 @@ private[sql] case class InMemoryColumnarTableScan(
def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]) = {
val rows = cacheBatches.flatMap { cachedBatch =>
// Build column accessors
- val columnAccessors =
- requestedColumnIndices.map(cachedBatch.buffers(_)).map(ColumnAccessor(_))
+ val columnAccessors = requestedColumnIndices.map { batch =>
+ ColumnAccessor(ByteBuffer.wrap(cachedBatch.buffers(batch)))
+ }
// Extract rows via column accessors
new Iterator[Row] {
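The switch from ByteBuffer to Array[Byte] presumably makes CachedBatch Java-serializable (plain byte arrays serialize and can spill to disk, while java.nio.ByteBuffer does not), which the MEMORY_AND_DISK change in CachedTableSuite further down appears to exercise. A minimal sketch of the round trip the scan now performs:

    import java.nio.ByteBuffer

    val buffer = ByteBuffer.allocate(8).putInt(1).putInt(2)
    val bytes: Array[Byte] = buffer.array()   // the serializable form kept in CachedBatch
    val readBack = ByteBuffer.wrap(bytes)     // re-wrapped without copying before reading
    assert(readBack.getInt() == 1 && readBack.getInt() == 2)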
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
index c386fd121c5de..38877c28de3a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
@@ -39,7 +39,8 @@ case class Generate(
child: SparkPlan)
extends UnaryNode {
- protected def generatorOutput: Seq[Attribute] = {
+ // This must be a val since the generator output expr ids are not preserved by serialization.
+ protected val generatorOutput: Seq[Attribute] = {
if (join && outer) {
generator.output.map(_.withNullability(true))
} else {
@@ -62,7 +63,7 @@ case class Generate(
newProjection(child.output ++ nullValues, child.output)
val joinProjection =
- newProjection(child.output ++ generator.output, child.output ++ generator.output)
+ newProjection(child.output ++ generatorOutput, child.output ++ generatorOutput)
val joinedRow = new JoinedRow
iter.flatMap {row =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index bbf17b9fadf86..4f1af7234d551 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -304,8 +304,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case class CommandStrategy(context: SQLContext) extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case logical.SetCommand(key, value) =>
- Seq(execution.SetCommand(key, value, plan.output)(context))
+ case logical.SetCommand(kv) =>
+ Seq(execution.SetCommand(kv, plan.output)(context))
case logical.ExplainCommand(logicalPlan, extended) =>
Seq(execution.ExplainCommand(logicalPlan, plan.output, extended)(context))
case logical.CacheTableCommand(tableName, optPlan, isLazy) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index d49633c24ad4d..5859eba408ee1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -48,29 +48,28 @@ trait Command {
* :: DeveloperApi ::
*/
@DeveloperApi
-case class SetCommand(
- key: Option[String], value: Option[String], output: Seq[Attribute])(
+case class SetCommand(kv: Option[(String, Option[String])], output: Seq[Attribute])(
@transient context: SQLContext)
extends LeafNode with Command with Logging {
- override protected lazy val sideEffectResult: Seq[Row] = (key, value) match {
- // Set value for key k.
- case (Some(k), Some(v)) =>
- if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) {
+ override protected lazy val sideEffectResult: Seq[Row] = kv match {
+ // Set value for the key.
+ case Some((key, Some(value))) =>
+ if (key == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) {
logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.")
- context.setConf(SQLConf.SHUFFLE_PARTITIONS, v)
- Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$v"))
+ context.setConf(SQLConf.SHUFFLE_PARTITIONS, value)
+ Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$value"))
} else {
- context.setConf(k, v)
- Seq(Row(s"$k=$v"))
+ context.setConf(key, value)
+ Seq(Row(s"$key=$value"))
}
- // Query the value bound to key k.
- case (Some(k), _) =>
+ // Query the value bound to the key.
+ case Some((key, None)) =>
// TODO (lian) This is just a workaround to make the Simba ODBC driver work.
// Should remove this once we get the ODBC driver updated.
- if (k == "-v") {
+ if (key == "-v") {
val hiveJars = Seq(
"hive-exec-0.12.0.jar",
"hive-service-0.12.0.jar",
@@ -84,23 +83,20 @@ case class SetCommand(
Row("system:java.class.path=" + hiveJars),
Row("system:sun.java.command=shark.SharkServer2"))
} else {
- if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) {
+ if (key == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) {
logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " +
s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.")
Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${context.numShufflePartitions}"))
} else {
- Seq(Row(s"$k=${context.getConf(k, "")}"))
+ Seq(Row(s"$key=${context.getConf(key, "")}"))
}
}
// Query all key-value pairs that are set in the SQLConf of the context.
- case (None, None) =>
+ case _ =>
context.getAllConfs.map { case (k, v) =>
Row(s"$k=$v")
}.toSeq
-
- case _ =>
- throw new IllegalArgumentException()
}
override def otherCopyArgs = context :: Nil
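A small sketch of how the three SQL forms now map onto the single kv argument. This is illustrative only; the actual parsing happens in SparkSQLParser, and the describe helper below is hypothetical:

    // SET key=value  ->  Some((key, Some(value)))
    // SET key        ->  Some((key, None))
    // SET            ->  None
    def describe(kv: Option[(String, Option[String])]): String = kv match {
      case Some((key, Some(value))) => s"assign $value to $key"
      case Some((key, None))        => s"look up the current value of $key"
      case None                     => "list every key=value pair that is set"
    }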
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index 0f27fd13e7379..61ee960aad9d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -20,7 +20,9 @@ package org.apache.spark.sql.json
import scala.collection.Map
import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper}
import scala.math.BigDecimal
+import java.sql.Timestamp
+import com.fasterxml.jackson.core.JsonProcessingException
import com.fasterxml.jackson.databind.ObjectMapper
import org.apache.spark.rdd.RDD
@@ -34,16 +36,19 @@ private[sql] object JsonRDD extends Logging {
private[sql] def jsonStringToRow(
json: RDD[String],
- schema: StructType): RDD[Row] = {
- parseJson(json).map(parsed => asRow(parsed, schema))
+ schema: StructType,
+ columnNameOfCorruptRecords: String): RDD[Row] = {
+ parseJson(json, columnNameOfCorruptRecords).map(parsed => asRow(parsed, schema))
}
private[sql] def inferSchema(
json: RDD[String],
- samplingRatio: Double = 1.0): StructType = {
+ samplingRatio: Double = 1.0,
+ columnNameOfCorruptRecords: String): StructType = {
require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0")
val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1)
- val allKeys = parseJson(schemaData).map(allKeysWithValueTypes).reduce(_ ++ _)
+ val allKeys =
+ parseJson(schemaData, columnNameOfCorruptRecords).map(allKeysWithValueTypes).reduce(_ ++ _)
createSchema(allKeys)
}
@@ -273,7 +278,9 @@ private[sql] object JsonRDD extends Logging {
case atom => atom
}
- private def parseJson(json: RDD[String]): RDD[Map[String, Any]] = {
+ private def parseJson(
+ json: RDD[String],
+ columnNameOfCorruptRecords: String): RDD[Map[String, Any]] = {
// According to [Jackson-72: https://jira.codehaus.org/browse/JACKSON-72],
// ObjectMapper will not return BigDecimal when
// "DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS" is disabled
@@ -288,12 +295,16 @@ private[sql] object JsonRDD extends Logging {
// For example: for {"key": 1, "key":2}, we will get "key"->2.
val mapper = new ObjectMapper()
iter.flatMap { record =>
- val parsed = mapper.readValue(record, classOf[Object]) match {
- case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil
- case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]]
- }
+ try {
+ val parsed = mapper.readValue(record, classOf[Object]) match {
+ case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil
+ case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]]
+ }
- parsed
+ parsed
+ } catch {
+ case e: JsonProcessingException => Map(columnNameOfCorruptRecords -> record) :: Nil
+ }
}
})
}
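A standalone sketch of the fallback pattern used here, outside of Spark: parse each record with Jackson and keep the raw text when parsing fails, so malformed input degrades into a queryable value instead of an exception. The parseOrKeep helper is hypothetical:

    import com.fasterxml.jackson.core.JsonProcessingException
    import com.fasterxml.jackson.databind.ObjectMapper

    val mapper = new ObjectMapper()
    // Right(parsed object) for valid JSON, Left(raw record) for corrupt input.
    def parseOrKeep(record: String): Either[String, AnyRef] =
      try Right(mapper.readValue(record, classOf[Object]))
      catch { case _: JsonProcessingException => Left(record) }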
@@ -361,6 +372,14 @@ private[sql] object JsonRDD extends Logging {
}
}
+ private def toTimestamp(value: Any): Timestamp = {
+ value match {
+ case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong)
+ case value: java.lang.Long => new Timestamp(value)
+ case value: java.lang.String => Timestamp.valueOf(value)
+ }
+ }
+
private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any ={
if (value == null) {
null
@@ -377,6 +396,7 @@ private[sql] object JsonRDD extends Logging {
case ArrayType(elementType, _) =>
value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType))
case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct)
+ case TimestampType => toTimestamp(value)
}
}
}
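With the TimestampType case added above, a JSON field can be coerced to a timestamp from epoch milliseconds or from a JDBC-style string. For reference, the three accepted shapes (the string format is whatever java.sql.Timestamp.valueOf accepts):

    import java.sql.Timestamp

    new Timestamp(86400000)                   // Int: treated as epoch milliseconds
    new Timestamp(1412078400000L)             // Long: epoch milliseconds
    Timestamp.valueOf("2014-09-30 12:34:56")  // String: yyyy-[m]m-[d]d hh:mm:ss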
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 1e624f97004f5..444bc95009c31 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
import org.apache.spark.sql.test.TestSQLContext._
-import org.apache.spark.storage.RDDBlockId
+import org.apache.spark.storage.{StorageLevel, RDDBlockId}
case class BigData(s: String)
@@ -55,10 +55,10 @@ class CachedTableSuite extends QueryTest {
test("too big for memory") {
val data = "*" * 10000
- sparkContext.parallelize(1 to 1000000, 1).map(_ => BigData(data)).registerTempTable("bigData")
- cacheTable("bigData")
- assert(table("bigData").count() === 1000000L)
- uncacheTable("bigData")
+ sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).registerTempTable("bigData")
+ table("bigData").persist(StorageLevel.MEMORY_AND_DISK)
+ assert(table("bigData").count() === 200000L)
+ table("bigData").unpersist()
}
test("calling .cache() should use in-memory columnar caching") {
@@ -69,7 +69,7 @@ class CachedTableSuite extends QueryTest {
test("calling .unpersist() should drop in-memory columnar cache") {
table("testData").cache()
table("testData").count()
- table("testData").unpersist(true)
+ table("testData").unpersist(blocking = true)
assertCached(table("testData"), 0)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index d001abb7e1fcc..45e58afe9d9a2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -147,6 +147,14 @@ class DslQuerySuite extends QueryTest {
(1, 1, 1, 2) :: Nil)
}
+ test("SPARK-3858 generator qualifiers are discarded") {
+ checkAnswer(
+ arrayData.as('ad)
+ .generate(Explode("data" :: Nil, 'data), alias = Some("ex"))
+ .select("ex.data".attr),
+ Seq(1, 2, 3, 2, 3, 4).map(Seq(_)))
+ }
+
test("average") {
checkAnswer(
testData2.groupBy()(avg('a)),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index b9b196ea5a46a..a94022c0cf6e3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -42,7 +42,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
TimeZone.setDefault(origZone)
}
-
test("SPARK-3176 Added Parser of SQL ABS()") {
checkAnswer(
sql("SELECT ABS(-1.3)"),
@@ -61,7 +60,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
4)
}
-
test("SPARK-2041 column name equals tablename") {
checkAnswer(
sql("SELECT tableName FROM tableName"),
@@ -680,9 +678,20 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"),
("true", "false") :: Nil)
}
-
+
test("SPARK-3371 Renaming a function expression with group by gives error") {
registerFunction("len", (s: String) => s.length)
checkAnswer(
- sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1)}
+ sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1)
+ }
+
+ test("SPARK-3813 CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END") {
+ checkAnswer(
+ sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 END FROM testData WHERE key = 1 group by key"), 1)
+ }
+
+ test("SPARK-3813 CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END") {
+ checkAnswer(
+ sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 685e788207725..7bb08f1b513ce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -21,8 +21,12 @@ import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType}
import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
+import java.sql.Timestamp
+
class JsonSuite extends QueryTest {
import TestJsonData._
TestJsonData
@@ -50,6 +54,12 @@ class JsonSuite extends QueryTest {
val doubleNumber: Double = 1.7976931348623157E308d
checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType))
checkTypePromotion(BigDecimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType))
+
+ checkTypePromotion(new Timestamp(intNumber), enforceCorrectType(intNumber, TimestampType))
+ checkTypePromotion(new Timestamp(intNumber.toLong),
+ enforceCorrectType(intNumber.toLong, TimestampType))
+ val strDate = "2014-09-30 12:34:56"
+ checkTypePromotion(Timestamp.valueOf(strDate), enforceCorrectType(strDate, TimestampType))
}
test("Get compatible type") {
@@ -636,7 +646,65 @@ class JsonSuite extends QueryTest {
("str_a_1", null, null) ::
("str_a_2", null, null) ::
(null, "str_b_3", null) ::
- ("str_a_4", "str_b_4", "str_c_4") ::Nil
+ ("str_a_4", "str_b_4", "str_c_4") :: Nil
+ )
+ }
+
+ test("Corrupt records") {
+ // Test if we can query corrupt records.
+ val oldColumnNameOfCorruptRecord = TestSQLContext.columnNameOfCorruptRecord
+ TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed")
+
+ val jsonSchemaRDD = jsonRDD(corruptRecords)
+ jsonSchemaRDD.registerTempTable("jsonTable")
+
+ val schema = StructType(
+ StructField("_unparsed", StringType, true) ::
+ StructField("a", StringType, true) ::
+ StructField("b", StringType, true) ::
+ StructField("c", StringType, true) :: Nil)
+
+ assert(schema === jsonSchemaRDD.schema)
+
+ // In HiveContext, backticks should be used to access columns starting with an underscore.
+ checkAnswer(
+ sql(
+ """
+ |SELECT a, b, c, _unparsed
+ |FROM jsonTable
+ """.stripMargin),
+ (null, null, null, "{") ::
+ (null, null, null, "") ::
+ (null, null, null, """{"a":1, b:2}""") ::
+ (null, null, null, """{"a":{, b:3}""") ::
+ ("str_a_4", "str_b_4", "str_c_4", null) ::
+ (null, null, null, "]") :: Nil
)
+
+ checkAnswer(
+ sql(
+ """
+ |SELECT a, b, c
+ |FROM jsonTable
+ |WHERE _unparsed IS NULL
+ """.stripMargin),
+ ("str_a_4", "str_b_4", "str_c_4") :: Nil
+ )
+
+ checkAnswer(
+ sql(
+ """
+ |SELECT _unparsed
+ |FROM jsonTable
+ |WHERE _unparsed IS NOT NULL
+ """.stripMargin),
+ Seq("{") ::
+ Seq("") ::
+ Seq("""{"a":1, b:2}""") ::
+ Seq("""{"a":{, b:3}""") ::
+ Seq("]") :: Nil
+ )
+
+ TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
index fc833b8b54e4c..eaca9f0508a12 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
@@ -143,4 +143,13 @@ object TestJsonData {
"""[{"a":"str_a_2"}, {"b":"str_b_3"}]""" ::
"""{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" ::
"""[]""" :: Nil)
+
+ val corruptRecords =
+ TestSQLContext.sparkContext.parallelize(
+ """{""" ::
+ """""" ::
+ """{"a":1, b:2}""" ::
+ """{"a":{, b:3}""" ::
+ """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" ::
+ """]""" :: Nil)
}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
index 910174a153768..accf61576b804 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
@@ -172,7 +172,7 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext)
result = hiveContext.sql(statement)
logDebug(result.queryExecution.toString())
result.queryExecution.logical match {
- case SetCommand(Some(key), Some(value)) if (key == SQLConf.THRIFTSERVER_POOL) =>
+ case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value)))) =>
sessionToActivePool(parentSession) = value
logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.")
case _ =>
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
index 3475c2c9db080..d68dd090b5e6c 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
@@ -62,9 +62,11 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging {
def captureOutput(source: String)(line: String) {
buffer += s"$source> $line"
- if (line.contains(expectedAnswers(next.get()))) {
- if (next.incrementAndGet() == expectedAnswers.size) {
- foundAllExpectedAnswers.trySuccess(())
+ if (next.get() < expectedAnswers.size) {
+ if (line.startsWith(expectedAnswers(next.get()))) {
+ if (next.incrementAndGet() == expectedAnswers.size) {
+ foundAllExpectedAnswers.trySuccess(())
+ }
}
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala
index c5844e92eaaa9..430ffb29989ea 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala
@@ -18,118 +18,50 @@
package org.apache.spark.sql.hive
import scala.language.implicitConversions
-import scala.util.parsing.combinator.syntactical.StandardTokenParsers
-import scala.util.parsing.combinator.PackratParsers
+
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.SqlLexical
+import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, SqlLexical}
/**
- * A parser that recognizes all HiveQL constructs together with several Spark SQL specific
- * extensions like CACHE TABLE and UNCACHE TABLE.
+ * A parser that recognizes all HiveQL constructs together with Spark SQL specific extensions.
*/
-private[hive] class ExtendedHiveQlParser extends StandardTokenParsers with PackratParsers {
-
- def apply(input: String): LogicalPlan = {
- // Special-case out set commands since the value fields can be
- // complex to handle without RegexParsers. Also this approach
- // is clearer for the several possible cases of set commands.
- if (input.trim.toLowerCase.startsWith("set")) {
- input.trim.drop(3).split("=", 2).map(_.trim) match {
- case Array("") => // "set"
- SetCommand(None, None)
- case Array(key) => // "set key"
- SetCommand(Some(key), None)
- case Array(key, value) => // "set key=value"
- SetCommand(Some(key), Some(value))
- }
- } else if (input.trim.startsWith("!")) {
- ShellCommand(input.drop(1))
- } else {
- phrase(query)(new lexical.Scanner(input)) match {
- case Success(r, x) => r
- case x => sys.error(x.toString)
- }
- }
- }
-
- protected case class Keyword(str: String)
-
- protected val ADD = Keyword("ADD")
- protected val AS = Keyword("AS")
- protected val CACHE = Keyword("CACHE")
- protected val DFS = Keyword("DFS")
- protected val FILE = Keyword("FILE")
- protected val JAR = Keyword("JAR")
- protected val LAZY = Keyword("LAZY")
- protected val SET = Keyword("SET")
- protected val SOURCE = Keyword("SOURCE")
- protected val TABLE = Keyword("TABLE")
- protected val UNCACHE = Keyword("UNCACHE")
-
+private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser {
protected implicit def asParser(k: Keyword): Parser[String] =
lexical.allCaseVersions(k.str).map(x => x : Parser[String]).reduce(_ | _)
- protected def allCaseConverse(k: String): Parser[String] =
- lexical.allCaseVersions(k).map(x => x : Parser[String]).reduce(_ | _)
+ protected val ADD = Keyword("ADD")
+ protected val DFS = Keyword("DFS")
+ protected val FILE = Keyword("FILE")
+ protected val JAR = Keyword("JAR")
- protected val reservedWords =
- this.getClass
+ private val reservedWords =
+ this
+ .getClass
.getMethods
.filter(_.getReturnType == classOf[Keyword])
.map(_.invoke(this).asInstanceOf[Keyword].str)
override val lexical = new SqlLexical(reservedWords)
- protected lazy val query: Parser[LogicalPlan] =
- cache | uncache | addJar | addFile | dfs | source | hiveQl
+ protected lazy val start: Parser[LogicalPlan] = dfs | addJar | addFile | hiveQl
protected lazy val hiveQl: Parser[LogicalPlan] =
restInput ^^ {
- case statement => HiveQl.createPlan(statement.trim())
- }
-
- // Returns the whole input string
- protected lazy val wholeInput: Parser[String] = new Parser[String] {
- def apply(in: Input) =
- Success(in.source.toString, in.drop(in.source.length()))
- }
-
- // Returns the rest of the input string that are not parsed yet
- protected lazy val restInput: Parser[String] = new Parser[String] {
- def apply(in: Input) =
- Success(
- in.source.subSequence(in.offset, in.source.length).toString,
- in.drop(in.source.length()))
- }
-
- protected lazy val cache: Parser[LogicalPlan] =
- CACHE ~> opt(LAZY) ~ (TABLE ~> ident) ~ opt(AS ~> hiveQl) ^^ {
- case isLazy ~ tableName ~ plan =>
- CacheTableCommand(tableName, plan, isLazy.isDefined)
- }
-
- protected lazy val uncache: Parser[LogicalPlan] =
- UNCACHE ~ TABLE ~> ident ^^ {
- case tableName => UncacheTableCommand(tableName)
+ case statement => HiveQl.createPlan(statement.trim)
}
- protected lazy val addJar: Parser[LogicalPlan] =
- ADD ~ JAR ~> restInput ^^ {
- case jar => AddJar(jar.trim())
+ protected lazy val dfs: Parser[LogicalPlan] =
+ DFS ~> wholeInput ^^ {
+ case command => NativeCommand(command.trim)
}
- protected lazy val addFile: Parser[LogicalPlan] =
+ private lazy val addFile: Parser[LogicalPlan] =
ADD ~ FILE ~> restInput ^^ {
- case file => AddFile(file.trim())
+ case input => AddFile(input.trim)
}
- protected lazy val dfs: Parser[LogicalPlan] =
- DFS ~> wholeInput ^^ {
- case command => NativeCommand(command.trim())
- }
-
- protected lazy val source: Parser[LogicalPlan] =
- SOURCE ~> restInput ^^ {
- case file => SourceCommand(file.trim())
+ private lazy val addJar: Parser[LogicalPlan] =
+ ADD ~ JAR ~> restInput ^^ {
+ case input => AddJar(input.trim)
}
}
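For context, a sketch of how this parser is now composed with the shared SparkSQLParser, mirroring the HiveQl change below. Both classes are package-private to their modules, so this only illustrates the wiring; statements such as SET and CACHE TABLE appear to be handled by SparkSQLParser, with everything else falling through to HiveQL:

    import org.apache.spark.sql.catalyst.SparkSQLParser

    val hiveFallback = new ExtendedHiveQlParser
    val parser = new SparkSQLParser(hiveFallback(_))

    parser("SET spark.sql.shuffle.partitions=10")  // recognized by SparkSQLParser
    parser("SELECT key FROM src")                  // falls through to HiveQl.createPlan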
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 180990877bafb..19047f9868b70 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -23,6 +23,7 @@ import org.apache.hadoop.hive.ql.lib.Node
import org.apache.hadoop.hive.ql.parse._
import org.apache.hadoop.hive.ql.plan.PlanUtils
+import org.apache.spark.sql.catalyst.SparkSQLParser
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
@@ -40,10 +41,6 @@ import scala.collection.JavaConversions._
*/
private[hive] case object NativePlaceholder extends Command
-private[hive] case class ShellCommand(cmd: String) extends Command
-
-private[hive] case class SourceCommand(filePath: String) extends Command
-
private[hive] case class AddFile(filePath: String) extends Command
private[hive] case class AddJar(path: String) extends Command
@@ -128,9 +125,11 @@ private[hive] object HiveQl {
"TOK_CREATETABLE",
"TOK_DESCTABLE"
) ++ nativeCommands
-
- // It parses hive sql query along with with several Spark SQL specific extensions
- protected val hiveSqlParser = new ExtendedHiveQlParser
+
+ protected val hqlParser = {
+ val fallback = new ExtendedHiveQlParser
+ new SparkSQLParser(fallback(_))
+ }
/**
* A set of implicit transformations that allow Hive ASTNodes to be rewritten by transformations
@@ -231,7 +230,7 @@ private[hive] object HiveQl {
/** Returns a LogicalPlan for a given HiveQL string. */
- def parseSql(sql: String): LogicalPlan = hiveSqlParser(sql)
+ def parseSql(sql: String): LogicalPlan = hqlParser(sql)
/** Creates LogicalPlan for a given HiveQL string. */
def createPlan(sql: String) = {
@@ -652,7 +651,7 @@ private[hive] object HiveQl {
def nodeToRelation(node: Node): LogicalPlan = node match {
case Token("TOK_SUBQUERY",
query :: Token(alias, Nil) :: Nil) =>
- Subquery(alias, nodeToPlan(query))
+ Subquery(cleanIdentifier(alias), nodeToPlan(query))
case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) =>
val Token("TOK_SELECT",
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 508d8239c7628..5c66322f1ed99 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -167,10 +167,10 @@ private[hive] trait HiveStrategies {
database.get,
tableName,
query,
- InsertIntoHiveTable(_: MetastoreRelation,
- Map(),
- query,
- true)(hiveContext)) :: Nil
+ InsertIntoHiveTable(_: MetastoreRelation,
+ Map(),
+ query,
+ overwrite = true)(hiveContext)) :: Nil
case _ => Nil
}
}
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java
new file mode 100644
index 0000000000000..6c4f378bc5471
--- /dev/null
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution;
+
+import org.apache.hadoop.hive.ql.exec.UDF;
+
+public class UDFIntegerToString extends UDF {
+ public String evaluate(Integer i) {
+ return i.toString();
+ }
+}
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java
new file mode 100644
index 0000000000000..d2d39a8c4dc28
--- /dev/null
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution;
+
+import org.apache.hadoop.hive.ql.exec.UDF;
+
+import java.util.List;
+
+public class UDFListListInt extends UDF {
+ /**
+ *
+ * @param obj
+ * SQL schema: array<array<int>>
+ * Java Type: List<List<Integer>>
+ * @return
+ */
+ public long evaluate(Object obj) {
+ if (obj == null) {
+ return 0L;
+ }
+ List listList = (List) obj;
+ long retVal = 0;
+ for (List aList : listList) {
+ @SuppressWarnings("unchecked")
+ List list = (List) aList;
+ @SuppressWarnings("unchecked")
+ Integer someInt = (Integer) list.get(1);
+ try {
+ retVal += (long) (someInt.intValue());
+ } catch (NullPointerException e) {
+ System.out.println(e);
+ }
+ }
+ return retVal;
+ }
+}
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java
new file mode 100644
index 0000000000000..efd34df293c88
--- /dev/null
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution;
+
+import org.apache.hadoop.hive.ql.exec.UDF;
+
+import java.util.List;
+import org.apache.commons.lang.StringUtils;
+
+public class UDFListString extends UDF {
+
+ public String evaluate(Object a) {
+ if (a == null) {
+ return null;
+ }
+ @SuppressWarnings("unchecked")
+ List s = (List) a;
+
+ return StringUtils.join(s, ',');
+ }
+
+
+}
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java
new file mode 100644
index 0000000000000..a369188d471e8
--- /dev/null
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution;
+
+import org.apache.hadoop.hive.ql.exec.UDF;
+
+public class UDFStringString extends UDF {
+ public String evaluate(String s1, String s2) {
+ return s1 + " " + s2;
+ }
+}
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java
new file mode 100644
index 0000000000000..0165591a7ce78
--- /dev/null
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution;
+
+import org.apache.hadoop.hive.ql.exec.UDF;
+
+public class UDFTwoListList extends UDF {
+ public String evaluate(Object o1, Object o2) {
+ UDFListListInt udf = new UDFListListInt();
+
+ return String.format("%s, %s", udf.evaluate(o1), udf.evaluate(o2));
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
index e4324e9528f9b..872f28d514efe 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
@@ -17,33 +17,37 @@
package org.apache.spark.sql.hive.execution
-import java.io.{DataOutput, DataInput}
+import java.io.{DataInput, DataOutput}
import java.util
import java.util.Properties
-import org.apache.spark.util.Utils
-
-import scala.collection.JavaConversions._
-
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.hive.serde2.{SerDeStats, AbstractSerDe}
-import org.apache.hadoop.io.Writable
-import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorFactory, ObjectInspector}
-
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject
-
-import org.apache.spark.sql.Row
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
+import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
+import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats}
+import org.apache.hadoop.io.Writable
+import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.hive.test.TestHive
-import org.apache.spark.sql.hive.test.TestHive._
+
+import org.apache.spark.util.Utils
+
+import scala.collection.JavaConversions._
case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int)
+// Case classes for the custom UDF's.
+case class IntegerCaseClass(i: Int)
+case class ListListIntCaseClass(lli: Seq[(Int, Int, Int)])
+case class StringCaseClass(s: String)
+case class ListStringCaseClass(l: Seq[String])
+
/**
* A test suite for Hive custom UDFs.
*/
-class HiveUdfSuite extends HiveComparisonTest {
+class HiveUdfSuite extends QueryTest {
+ import TestHive._
test("spark sql udf test that returns a struct") {
registerFunction("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5))
@@ -81,7 +85,84 @@ class HiveUdfSuite extends HiveComparisonTest {
}
test("SPARK-2693 udaf aggregates test") {
- assert(sql("SELECT percentile(key,1) FROM src").first === sql("SELECT max(key) FROM src").first)
+ checkAnswer(sql("SELECT percentile(key,1) FROM src LIMIT 1"),
+ sql("SELECT max(key) FROM src").collect().toSeq)
+ }
+
+ test("UDFIntegerToString") {
+ val testData = TestHive.sparkContext.parallelize(
+ IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil)
+ testData.registerTempTable("integerTable")
+
+ sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '${classOf[UDFIntegerToString].getName}'")
+ checkAnswer(
+ sql("SELECT testUDFIntegerToString(i) FROM integerTable"), //.collect(),
+ Seq(Seq("1"), Seq("2")))
+ sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString")
+
+ TestHive.reset()
+ }
+
+ test("UDFListListInt") {
+ val testData = TestHive.sparkContext.parallelize(
+ ListListIntCaseClass(Nil) ::
+ ListListIntCaseClass(Seq((1, 2, 3))) ::
+ ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil)
+ testData.registerTempTable("listListIntTable")
+
+ sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'")
+ checkAnswer(
+ sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), //.collect(),
+ Seq(Seq(0), Seq(2), Seq(13)))
+ sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt")
+
+ TestHive.reset()
+ }
+
+ test("UDFListString") {
+ val testData = TestHive.sparkContext.parallelize(
+ ListStringCaseClass(Seq("a", "b", "c")) ::
+ ListStringCaseClass(Seq("d", "e")) :: Nil)
+ testData.registerTempTable("listStringTable")
+
+ sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'")
+ checkAnswer(
+ sql("SELECT testUDFListString(l) FROM listStringTable"), //.collect(),
+ Seq(Seq("a,b,c"), Seq("d,e")))
+ sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString")
+
+ TestHive.reset()
+ }
+
+ test("UDFStringString") {
+ val testData = TestHive.sparkContext.parallelize(
+ StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil)
+ testData.registerTempTable("stringTable")
+
+ sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'")
+ checkAnswer(
+ sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), //.collect(),
+ Seq(Seq("hello world"), Seq("hello goodbye")))
+ sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf")
+
+ TestHive.reset()
+ }
+
+ test("UDFTwoListList") {
+ val testData = TestHive.sparkContext.parallelize(
+ ListListIntCaseClass(Nil) ::
+ ListListIntCaseClass(Seq((1, 2, 3))) ::
+ ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) ::
+ Nil)
+ testData.registerTempTable("TwoListTable")
+
+ sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'")
+ checkAnswer(
+ sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), //.collect(),
+ Seq(Seq("0, 0"), Seq("2, 2"), Seq("13, 13")))
+ sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList")
+
+ TestHive.reset()
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 3647bb1c4ce7d..fbe6ac765c009 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -68,5 +68,11 @@ class SQLQuerySuite extends QueryTest {
checkAnswer(
sql("SELECT k FROM (SELECT `key` AS `k` FROM src) a"),
sql("SELECT `key` FROM src").collect().toSeq)
- }
+ }
+
+ test("SPARK-3834 Backticks not correctly handled in subquery aliases") {
+ checkAnswer(
+ sql("SELECT a.key FROM (SELECT key FROM src) `a`"),
+ sql("SELECT `key` FROM src").collect().toSeq)
+ }
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
index a6184de4e83c1..2a7004e56ef53 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
@@ -167,7 +167,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
new JavaPairDStream(dstream.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2])
}
- /**
+ /**
* Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs
* of this DStream. Applying mapPartitions() to an RDD applies a function to each partition
* of the RDD.
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
new file mode 100644
index 0000000000000..213dff6a76354
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.api.python
+
+import java.io.{ObjectInputStream, ObjectOutputStream}
+import java.lang.reflect.Proxy
+import java.util.{ArrayList => JArrayList, List => JList}
+import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+import scala.language.existentials
+
+import py4j.GatewayServer
+
+import org.apache.spark.api.java._
+import org.apache.spark.api.python._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.{Interval, Duration, Time}
+import org.apache.spark.streaming.dstream._
+import org.apache.spark.streaming.api.java._
+
+
+/**
+ * Interface for Python callback function which is used to transform RDDs
+ */
+private[python] trait PythonTransformFunction {
+ def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]]
+}
+
+/**
+ * Interface for Python Serializer to serialize PythonTransformFunction
+ */
+private[python] trait PythonTransformFunctionSerializer {
+ def dumps(id: String): Array[Byte]
+ def loads(bytes: Array[Byte]): PythonTransformFunction
+}
+
+/**
+ * Wraps a PythonTransformFunction (which is a Python object accessed through Py4J)
+ * so that it looks like a Scala function and can be transparently serialized and
+ * deserialized by Java.
+ */
+private[python] class TransformFunction(@transient var pfunc: PythonTransformFunction)
+ extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] {
+
+ def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
+ Option(pfunc.call(time.milliseconds, List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava))
+ .map(_.rdd)
+ }
+
+ def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
+ val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull, rdd2.map(JavaRDD.fromRDD(_)).orNull).asJava
+ Option(pfunc.call(time.milliseconds, rdds)).map(_.rdd)
+ }
+
+ // for function.Function2
+ def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = {
+ pfunc.call(time.milliseconds, rdds)
+ }
+
+ private def writeObject(out: ObjectOutputStream): Unit = {
+ val bytes = PythonTransformFunctionSerializer.serialize(pfunc)
+ out.writeInt(bytes.length)
+ out.write(bytes)
+ }
+
+ private def readObject(in: ObjectInputStream): Unit = {
+ val length = in.readInt()
+ val bytes = new Array[Byte](length)
+ in.readFully(bytes)
+ pfunc = PythonTransformFunctionSerializer.deserialize(bytes)
+ }
+}
+
+/**
+ * Helpers for PythonTransformFunctionSerializer
+ *
+ * PythonTransformFunctionSerializer is logically a singleton that happens to be
+ * implemented as a Python object.
+ */
+private[python] object PythonTransformFunctionSerializer {
+
+ /**
+ * A serializer in Python, used to serialize PythonTransformFunction
+ */
+ private var serializer: PythonTransformFunctionSerializer = _
+
+ /*
+ * Register a serializer from Python; this should be called during initialization.
+ */
+ def register(ser: PythonTransformFunctionSerializer): Unit = {
+ serializer = ser
+ }
+
+ def serialize(func: PythonTransformFunction): Array[Byte] = {
+ assert(serializer != null, "Serializer has not been registered!")
+ // get the id of PythonTransformFunction in py4j
+ val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy])
+ val f = h.getClass().getDeclaredField("id")
+ f.setAccessible(true)
+ val id = f.get(h).asInstanceOf[String]
+ serializer.dumps(id)
+ }
+
+ def deserialize(bytes: Array[Byte]): PythonTransformFunction = {
+ assert(serializer != null, "Serializer has not been registered!")
+ serializer.loads(bytes)
+ }
+}
+
+/**
+ * Helper functions, which are called from Python via Py4J.
+ */
+private[python] object PythonDStream {
+
+ /**
+ * Cannot call PythonTransformFunctionSerializer.register() directly via Py4J;
+ * it fails with Py4JError: PythonTransformFunctionSerializerregister does not exist in the JVM
+ */
+ def registerSerializer(ser: PythonTransformFunctionSerializer): Unit = {
+ PythonTransformFunctionSerializer.register(ser)
+ }
+
+ /**
+ * Update the port of callback client to `port`
+ */
+ def updatePythonGatewayPort(gws: GatewayServer, port: Int): Unit = {
+ val cl = gws.getCallbackClient
+ val f = cl.getClass.getDeclaredField("port")
+ f.setAccessible(true)
+ f.setInt(cl, port)
+ }
+
+ /**
+ * Helper function for DStream.foreachRDD().
+ * It cannot be named `foreachRDD`; that would confuse Py4J.
+ */
+ def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonTransformFunction) {
+ val func = new TransformFunction(pfunc)
+ jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time))
+ }
+
+ /**
+ * Convert a list of RDDs into a queue of RDDs, for ssc.queueStream().
+ */
+ def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = {
+ val queue = new java.util.LinkedList[JavaRDD[Array[Byte]]]
+ rdds.forall(queue.add(_))
+ queue
+ }
+}
+
+/**
+ * Base class for PythonDStream with some common methods
+ */
+private[python] abstract class PythonDStream(
+ parent: DStream[_],
+ @transient pfunc: PythonTransformFunction)
+ extends DStream[Array[Byte]] (parent.ssc) {
+
+ val func = new TransformFunction(pfunc)
+
+ override def dependencies = List(parent)
+
+ override def slideDuration: Duration = parent.slideDuration
+
+ val asJavaDStream = JavaDStream.fromDStream(this)
+}
+
+/**
+ * Transformed DStream in Python.
+ */
+private[python] class PythonTransformedDStream (
+ parent: DStream[_],
+ @transient pfunc: PythonTransformFunction)
+ extends PythonDStream(parent, pfunc) {
+
+ override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+ val rdd = parent.getOrCompute(validTime)
+ if (rdd.isDefined) {
+ func(rdd, validTime)
+ } else {
+ None
+ }
+ }
+}
+
+/**
+ * Transformed from two DStreams in Python.
+ */
+private[python] class PythonTransformed2DStream(
+ parent: DStream[_],
+ parent2: DStream[_],
+ @transient pfunc: PythonTransformFunction)
+ extends DStream[Array[Byte]] (parent.ssc) {
+
+ val func = new TransformFunction(pfunc)
+
+ override def dependencies = List(parent, parent2)
+
+ override def slideDuration: Duration = parent.slideDuration
+
+ override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+ val empty: RDD[_] = ssc.sparkContext.emptyRDD
+ val rdd1 = parent.getOrCompute(validTime).getOrElse(empty)
+ val rdd2 = parent2.getOrCompute(validTime).getOrElse(empty)
+ func(Some(rdd1), Some(rdd2), validTime)
+ }
+
+ val asJavaDStream = JavaDStream.fromDStream(this)
+}
+
+/**
+ * similar to StateDStream
+ */
+private[python] class PythonStateDStream(
+ parent: DStream[Array[Byte]],
+ @transient reduceFunc: PythonTransformFunction)
+ extends PythonDStream(parent, reduceFunc) {
+
+ super.persist(StorageLevel.MEMORY_ONLY)
+ override val mustCheckpoint = true
+
+ override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+ val lastState = getOrCompute(validTime - slideDuration)
+ val rdd = parent.getOrCompute(validTime)
+ if (rdd.isDefined) {
+ func(lastState, rdd, validTime)
+ } else {
+ lastState
+ }
+ }
+}
+
+/**
+ * similar to ReducedWindowedDStream
+ */
+private[python] class PythonReducedWindowedDStream(
+ parent: DStream[Array[Byte]],
+ @transient preduceFunc: PythonTransformFunction,
+ @transient pinvReduceFunc: PythonTransformFunction,
+ _windowDuration: Duration,
+ _slideDuration: Duration)
+ extends PythonDStream(parent, preduceFunc) {
+
+ super.persist(StorageLevel.MEMORY_ONLY)
+ override val mustCheckpoint = true
+
+ val invReduceFunc = new TransformFunction(pinvReduceFunc)
+
+ def windowDuration: Duration = _windowDuration
+ override def slideDuration: Duration = _slideDuration
+ override def parentRememberDuration: Duration = rememberDuration + windowDuration
+
+ override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+ val currentTime = validTime
+ val current = new Interval(currentTime - windowDuration, currentTime)
+ val previous = current - slideDuration
+
+ // _____________________________
+ // | previous window _________|___________________
+ // |___________________| current window | --------------> Time
+ // |_____________________________|
+ //
+ // |________ _________| |________ _________|
+ // | |
+ // V V
+ // old RDDs new RDDs
+ //
+ val previousRDD = getOrCompute(previous.endTime)
+
+ // for a small window, reducing once is better than reducing twice
+ if (pinvReduceFunc != null && previousRDD.isDefined
+ && windowDuration >= slideDuration * 5) {
+
+ // subtract the values from old RDDs
+ val oldRDDs = parent.slice(previous.beginTime + parent.slideDuration, current.beginTime)
+ val subtracted = if (oldRDDs.size > 0) {
+ invReduceFunc(previousRDD, Some(ssc.sc.union(oldRDDs)), validTime)
+ } else {
+ previousRDD
+ }
+
+ // add the RDDs of the reduced values in "new time steps"
+ val newRDDs = parent.slice(previous.endTime + parent.slideDuration, current.endTime)
+ if (newRDDs.size > 0) {
+ func(subtracted, Some(ssc.sc.union(newRDDs)), validTime)
+ } else {
+ subtracted
+ }
+ } else {
+ // Get the RDDs of the reduced values in current window
+ val currentRDDs = parent.slice(current.beginTime + parent.slideDuration, current.endTime)
+ if (currentRDDs.size > 0) {
+ func(None, Some(ssc.sc.union(currentRDDs)), validTime)
+ } else {
+ None
+ }
+ }
+ }
+}
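A worked sketch of the incremental branch above, assuming 10 second batches, a 30 second window and a 10 second slide. At time t = 60s the current window is (30s, 60s] and the previous window is (20s, 50s], so instead of re-reducing three batches the stream reuses the previous window's result, subtracts the batch that dropped out (the one ending at 30s) and adds the batch that just arrived (the one ending at 60s):

    val batch = 10L; val window = 30L; val slide = 10L; val t = 60L
    val current  = (t - window, t)                              // (30, 60)
    val previous = (current._1 - slide, current._2 - slide)     // (20, 50)
    val dropped  = (previous._1 + batch) to current._1 by batch // the batch ending at 30
    val added    = (previous._2 + batch) to current._2 by batch // the batch ending at 60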
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index 8511390cb1ad5..e5592e52b0d2d 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -231,8 +231,7 @@ class CheckpointSuite extends TestSuiteBase {
// failure, are re-processed or not.
test("recovery with file input stream") {
// Set up the streaming context and input streams
- val testDir = Files.createTempDir()
- testDir.deleteOnExit()
+ val testDir = Utils.createTempDir()
var ssc = new StreamingContext(master, framework, Seconds(1))
ssc.checkpoint(checkpointDir)
val fileStream = ssc.textFileStream(testDir.toString)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index 6107fcdc447b6..fa04fa326e370 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -96,8 +96,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock")
// Set up the streaming context and input streams
- val testDir = Files.createTempDir()
- testDir.deleteOnExit()
+ val testDir = Utils.createTempDir()
val ssc = new StreamingContext(conf, batchDuration)
val fileStream = ssc.textFileStream(testDir.toString)
val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala
index c53c01706083a..5dbb7232009eb 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala
@@ -352,8 +352,7 @@ class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long)
extends Thread with Logging {
override def run() {
- val localTestDir = Files.createTempDir()
- localTestDir.deleteOnExit()
+ val localTestDir = Utils.createTempDir()
var fs = testDir.getFileSystem(new Configuration())
val maxTries = 3
try {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index 759baacaa4308..9327ff4822699 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -24,12 +24,12 @@ import scala.collection.mutable.SynchronizedBuffer
import scala.reflect.ClassTag
import org.scalatest.{BeforeAndAfter, FunSuite}
-import com.google.common.io.Files
import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream}
import org.apache.spark.streaming.util.ManualClock
import org.apache.spark.{SparkConf, Logging}
import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
/**
* This is a input stream just for the testsuites. This is equivalent to a checkpointable,
@@ -120,9 +120,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
// Directory where the checkpoint data will be saved
lazy val checkpointDir = {
- val dir = Files.createTempDir()
+ val dir = Utils.createTempDir()
logDebug(s"checkpointDir: $dir")
- dir.deleteOnExit()
dir.toString
}
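The test-suite hunks here (CheckpointSuite, InputStreamsSuite, MasterFailureTest, TestSuiteBase, and ClientBaseSuite below) all make the same substitution: Guava's Files.createTempDir() followed by File#deleteOnExit() is replaced with Spark's Utils.createTempDir(). deleteOnExit only removes a directory if it is empty at JVM exit, so test runs tended to leak populated temp directories. A rough sketch of what such a helper can look like (an illustration of the idea, not the actual org.apache.spark.util.Utils implementation):

    import java.io.File
    import java.nio.file.Files

    object TempDirSketch {
      // Create a temp directory and register a shutdown hook that deletes it
      // recursively, so it is cleaned up even when it still contains files.
      def createTempDir(prefix: String = "spark-test"): File = {
        val dir = Files.createTempDirectory(prefix).toFile
        Runtime.getRuntime.addShutdownHook(new Thread(new Runnable {
          override def run(): Unit = deleteRecursively(dir)
        }))
        dir
      }

      private def deleteRecursively(f: File): Unit = {
        Option(f.listFiles()).foreach(_.foreach(deleteRecursively))
        f.delete()
      }
    }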
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 5a20532315e59..5c7bca4541222 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -122,7 +122,7 @@ private[spark] class Client(
* ApplicationReport#getClientToken is renamed `getClientToAMToken` in the stable API.
*/
override def getClientToken(report: ApplicationReport): String =
- Option(report.getClientToken).getOrElse("")
+ Option(report.getClientToken).map(_.toString).getOrElse("")
}
object Client {
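The one-line change in Client.getClientToken makes the alpha-API path both null-safe and type-agnostic: the token is wrapped in an Option and converted with toString instead of being assumed to already be a String. A tiny self-contained sketch of the same pattern (FakeToken is a stand-in type for illustration, not the YARN class):

    object TokenSketch {
      // Stand-in for a token type whose toString yields the value we want.
      final class FakeToken(value: String) { override def toString: String = value }

      def tokenString(token: FakeToken): String =
        Option(token).map(_.toString).getOrElse("")   // null => ""

      def main(args: Array[String]): Unit = {
        assert(tokenString(new FakeToken("secret")) == "secret")
        assert(tokenString(null) == "")
      }
    }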
diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala
index 9bd916100dd2c..17b79ae1d82c4 100644
--- a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala
+++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala
@@ -20,13 +20,10 @@ package org.apache.spark.deploy.yarn
import java.io.File
import java.net.URI
-import com.google.common.io.Files
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.MRJobConfig
-import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
-import org.apache.hadoop.yarn.api.protocolrecords.GetNewApplicationResponse
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.mockito.Matchers._
@@ -117,7 +114,7 @@ class ClientBaseSuite extends FunSuite with Matchers {
doReturn(new Path("/")).when(client).copyFileToRemote(any(classOf[Path]),
any(classOf[Path]), anyShort(), anyBoolean())
- val tempDir = Files.createTempDir()
+ val tempDir = Utils.createTempDir()
try {
client.prepareLocalResources(tempDir.getAbsolutePath())
sparkConf.getOption(ClientBase.CONF_SPARK_USER_JAR) should be (Some(USER))