org.scalatest
scalatest_${scala.binary.version}
diff --git a/core/src/main/java/org/apache/spark/api/java/StorageLevels.java b/core/src/main/java/org/apache/spark/api/java/StorageLevels.java
index 9f13b39909481..840a1bd93bfbb 100644
--- a/core/src/main/java/org/apache/spark/api/java/StorageLevels.java
+++ b/core/src/main/java/org/apache/spark/api/java/StorageLevels.java
@@ -23,17 +23,18 @@
* Expose some commonly useful storage level constants.
*/
public class StorageLevels {
- public static final StorageLevel NONE = create(false, false, false, 1);
- public static final StorageLevel DISK_ONLY = create(true, false, false, 1);
- public static final StorageLevel DISK_ONLY_2 = create(true, false, false, 2);
- public static final StorageLevel MEMORY_ONLY = create(false, true, true, 1);
- public static final StorageLevel MEMORY_ONLY_2 = create(false, true, true, 2);
- public static final StorageLevel MEMORY_ONLY_SER = create(false, true, false, 1);
- public static final StorageLevel MEMORY_ONLY_SER_2 = create(false, true, false, 2);
- public static final StorageLevel MEMORY_AND_DISK = create(true, true, true, 1);
- public static final StorageLevel MEMORY_AND_DISK_2 = create(true, true, true, 2);
- public static final StorageLevel MEMORY_AND_DISK_SER = create(true, true, false, 1);
- public static final StorageLevel MEMORY_AND_DISK_SER_2 = create(true, true, false, 2);
+ public static final StorageLevel NONE = create(false, false, false, false, 1);
+ public static final StorageLevel DISK_ONLY = create(true, false, false, false, 1);
+ public static final StorageLevel DISK_ONLY_2 = create(true, false, false, false, 2);
+ public static final StorageLevel MEMORY_ONLY = create(false, true, false, true, 1);
+ public static final StorageLevel MEMORY_ONLY_2 = create(false, true, false, true, 2);
+ public static final StorageLevel MEMORY_ONLY_SER = create(false, true, false, false, 1);
+ public static final StorageLevel MEMORY_ONLY_SER_2 = create(false, true, false, false, 2);
+ public static final StorageLevel MEMORY_AND_DISK = create(true, true, false, true, 1);
+ public static final StorageLevel MEMORY_AND_DISK_2 = create(true, true, false, true, 2);
+ public static final StorageLevel MEMORY_AND_DISK_SER = create(true, true, false, false, 1);
+ public static final StorageLevel MEMORY_AND_DISK_SER_2 = create(true, true, false, false, 2);
+ public static final StorageLevel OFF_HEAP = create(false, false, true, false, 1);
/**
* Create a new StorageLevel object.
@@ -42,7 +43,26 @@ public class StorageLevels {
* @param deserialized saved as deserialized objects, if true
* @param replication replication factor
*/
- public static StorageLevel create(boolean useDisk, boolean useMemory, boolean deserialized, int replication) {
- return StorageLevel.apply(useDisk, useMemory, deserialized, replication);
+ @Deprecated
+ public static StorageLevel create(boolean useDisk, boolean useMemory, boolean deserialized,
+ int replication) {
+ return StorageLevel.apply(useDisk, useMemory, false, deserialized, replication);
+ }
+
+ /**
+ * Create a new StorageLevel object.
+ * @param useDisk saved to disk, if true
+ * @param useMemory saved to memory, if true
+ * @param useOffHeap saved to Tachyon, if true
+ * @param deserialized saved as deserialized objects, if true
+ * @param replication replication factor
+ */
+ public static StorageLevel create(
+ boolean useDisk,
+ boolean useMemory,
+ boolean useOffHeap,
+ boolean deserialized,
+ int replication) {
+ return StorageLevel.apply(useDisk, useMemory, useOffHeap, deserialized, replication);
}
}
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 4dd298177f07d..e5ebd350eeced 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -19,14 +19,13 @@ package org.apache.spark
import java.io._
import java.net.URI
-import java.util.{Properties, UUID}
import java.util.concurrent.atomic.AtomicInteger
-
+import java.util.{Properties, UUID}
+import java.util.UUID.randomUUID
import scala.collection.{Map, Set}
import scala.collection.generic.Growable
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.reflect.{ClassTag, classTag}
-
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable}
@@ -35,7 +34,9 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHad
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.mesos.MesosNativeLibrary
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
+import org.apache.spark.input.WholeTextFileInputFormat
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
@@ -128,6 +129,11 @@ class SparkContext(
val master = conf.get("spark.master")
val appName = conf.get("spark.app.name")
+ // Generate the random name for a temp folder in Tachyon
+ // Add a timestamp as the suffix here to make it more safe
+ val tachyonFolderName = "spark-" + randomUUID.toString()
+ conf.set("spark.tachyonStore.folderName", tachyonFolderName)
+
val isLocal = (master == "local" || master.startsWith("local["))
if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true")
@@ -230,7 +236,7 @@ class SparkContext(
postEnvironmentUpdate()
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
- val hadoopConfiguration = {
+ val hadoopConfiguration: Configuration = {
val env = SparkEnv.get
val hadoopConf = SparkHadoopUtil.get.newConfiguration()
// Explicitly check for S3 environment variables
@@ -370,6 +376,39 @@ class SparkContext(
minSplits).map(pair => pair._2.toString)
}
+ /**
+ * Read a directory of text files from HDFS, a local file system (available on all nodes), or any
+ * Hadoop-supported file system URI. Each file is read as a single record and returned in a
+ * key-value pair, where the key is the path of each file, the value is the content of each file.
+ *
+ * For example, if you have the following files:
+ * {{{
+ * hdfs://a-hdfs-path/part-00000
+ * hdfs://a-hdfs-path/part-00001
+ * ...
+ * hdfs://a-hdfs-path/part-nnnnn
+ * }}}
+ *
+ * Do `val rdd = sparkContext.wholeTextFile("hdfs://a-hdfs-path")`,
+ *
+ *
then `rdd` contains
+ * {{{
+ * (a-hdfs-path/part-00000, its content)
+ * (a-hdfs-path/part-00001, its content)
+ * ...
+ * (a-hdfs-path/part-nnnnn, its content)
+ * }}}
+ *
+ * @note Small files are preferred, as each file will be loaded fully in memory.
+ */
+ def wholeTextFiles(path: String): RDD[(String, String)] = {
+ newAPIHadoopFile(
+ path,
+ classOf[WholeTextFileInputFormat],
+ classOf[String],
+ classOf[String])
+ }
+
/**
* Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and other
* necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable),
@@ -630,7 +669,7 @@ class SparkContext(
* standard mutable collections. So you can use this with mutable Map, Set, etc.
*/
def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T]
- (initialValue: R) = {
+ (initialValue: R): Accumulable[R, T] = {
val param = new GrowableAccumulableParam[R,T]
new Accumulable(initialValue, param)
}
@@ -640,7 +679,7 @@ class SparkContext(
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
* The variable will be sent to each cluster only once.
*/
- def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
+ def broadcast[T](value: T): Broadcast[T] = env.broadcastManager.newBroadcast[T](value, isLocal)
/**
* Add a file to be downloaded with this Spark job on every node.
@@ -692,10 +731,6 @@ class SparkContext(
*/
def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap
- def getStageInfo: Map[Stage, StageInfo] = {
- dagScheduler.stageToInfos
- }
-
/**
* Return information about blocks stored in all of the slaves
*/
@@ -1126,7 +1161,7 @@ object SparkContext extends Logging {
implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = new AsyncRDDActions(rdd)
implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag](
- rdd: RDD[(K, V)]) =
+ rdd: RDD[(K, V)]) =
new SequenceFileRDDFunctions(rdd)
implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassTag, V: ClassTag](
@@ -1163,27 +1198,33 @@ object SparkContext extends Logging {
}
// Helper objects for converting common types to Writable
- private def simpleWritableConverter[T, W <: Writable: ClassTag](convert: W => T) = {
+ private def simpleWritableConverter[T, W <: Writable: ClassTag](convert: W => T)
+ : WritableConverter[T] = {
val wClass = classTag[W].runtimeClass.asInstanceOf[Class[W]]
new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W]))
}
- implicit def intWritableConverter() = simpleWritableConverter[Int, IntWritable](_.get)
+ implicit def intWritableConverter(): WritableConverter[Int] =
+ simpleWritableConverter[Int, IntWritable](_.get)
- implicit def longWritableConverter() = simpleWritableConverter[Long, LongWritable](_.get)
+ implicit def longWritableConverter(): WritableConverter[Long] =
+ simpleWritableConverter[Long, LongWritable](_.get)
- implicit def doubleWritableConverter() = simpleWritableConverter[Double, DoubleWritable](_.get)
+ implicit def doubleWritableConverter(): WritableConverter[Double] =
+ simpleWritableConverter[Double, DoubleWritable](_.get)
- implicit def floatWritableConverter() = simpleWritableConverter[Float, FloatWritable](_.get)
+ implicit def floatWritableConverter(): WritableConverter[Float] =
+ simpleWritableConverter[Float, FloatWritable](_.get)
- implicit def booleanWritableConverter() =
+ implicit def booleanWritableConverter(): WritableConverter[Boolean] =
simpleWritableConverter[Boolean, BooleanWritable](_.get)
- implicit def bytesWritableConverter() = {
+ implicit def bytesWritableConverter(): WritableConverter[Array[Byte]] = {
simpleWritableConverter[Array[Byte], BytesWritable](_.getBytes)
}
- implicit def stringWritableConverter() = simpleWritableConverter[String, Text](_.toString)
+ implicit def stringWritableConverter(): WritableConverter[String] =
+ simpleWritableConverter[String, Text](_.toString)
implicit def writableWritableConverter[T <: Writable]() =
new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T])
@@ -1244,8 +1285,8 @@ object SparkContext extends Logging {
/** Creates a task scheduler based on a given master URL. Extracted for testing. */
private def createTaskScheduler(sc: SparkContext, master: String): TaskScheduler = {
- // Regular expression used for local[N] master format
- val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
+ // Regular expression used for local[N] and local[*] master formats
+ val LOCAL_N_REGEX = """local\[([0-9\*]+)\]""".r
// Regular expression for local[N, maxRetries], used in tests with failing tasks
val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r
// Regular expression for simulating a Spark cluster of [N, cores, memory] locally
@@ -1268,8 +1309,11 @@ object SparkContext extends Logging {
scheduler
case LOCAL_N_REGEX(threads) =>
+ def localCpuCount = Runtime.getRuntime.availableProcessors()
+ // local[*] estimates the number of cores on the machine; local[N] uses exactly N threads.
+ val threadCount = if (threads == "*") localCpuCount else threads.toInt
val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
- val backend = new LocalBackend(scheduler, threads.toInt)
+ val backend = new LocalBackend(scheduler, threadCount)
scheduler.initialize(backend)
scheduler
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index a1af63fa4a391..5ceac28fe7afb 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -81,7 +81,7 @@ class SparkEnv private[spark] (
// Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut
// down, but let's call it anyway in case it gets fixed in a later release
// UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it.
- //actorSystem.awaitTermination()
+ // actorSystem.awaitTermination()
}
private[spark]
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index ddac553304233..6e8ec8e0c7629 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -17,7 +17,8 @@
package org.apache.spark.api.java
-import java.util.{Comparator, List => JList}
+import java.util.{Comparator, Iterator => JIterator, List => JList}
+import java.lang.{Iterable => JIterable}
import scala.collection.JavaConversions._
import scala.reflect.ClassTag
@@ -280,6 +281,17 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
new java.util.ArrayList(arr)
}
+ /**
+ * Return an iterator that contains all of the elements in this RDD.
+ *
+ * The iterator will consume as much memory as the largest partition in this RDD.
+ */
+ def toLocalIterator(): JIterator[T] = {
+ import scala.collection.JavaConversions._
+ rdd.toLocalIterator
+ }
+
+
/**
* Return an array that contains all of the elements in this RDD.
* @deprecated As of Spark 1.0.0, toArray() is deprecated, use {@link #collect()} instead
@@ -391,19 +403,24 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
/**
* Save this RDD as a text file, using string representations of elements.
*/
- def saveAsTextFile(path: String) = rdd.saveAsTextFile(path)
+ def saveAsTextFile(path: String): Unit = {
+ rdd.saveAsTextFile(path)
+ }
/**
* Save this RDD as a compressed text file, using string representations of elements.
*/
- def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) =
+ def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]): Unit = {
rdd.saveAsTextFile(path, codec)
+ }
/**
* Save this RDD as a SequenceFile of serialized objects.
*/
- def saveAsObjectFile(path: String) = rdd.saveAsObjectFile(path)
+ def saveAsObjectFile(path: String): Unit = {
+ rdd.saveAsObjectFile(path)
+ }
/**
* Creates tuples of the elements in this RDD by applying `f`.
@@ -420,7 +437,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* executed on this RDD. It is strongly recommended that this RDD is persisted in
* memory, otherwise saving it on a file will require recomputation.
*/
- def checkpoint() = rdd.checkpoint()
+ def checkpoint(): Unit = {
+ rdd.checkpoint()
+ }
/**
* Return whether this RDD has been checkpointed or not
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index 35508b6e5acba..a2855d4db1d2e 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -154,6 +154,34 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
*/
def textFile(path: String, minSplits: Int): JavaRDD[String] = sc.textFile(path, minSplits)
+ /**
+ * Read a directory of text files from HDFS, a local file system (available on all nodes), or any
+ * Hadoop-supported file system URI. Each file is read as a single record and returned in a
+ * key-value pair, where the key is the path of each file, the value is the content of each file.
+ *
+ *
For example, if you have the following files:
+ * {{{
+ * hdfs://a-hdfs-path/part-00000
+ * hdfs://a-hdfs-path/part-00001
+ * ...
+ * hdfs://a-hdfs-path/part-nnnnn
+ * }}}
+ *
+ * Do `JavaPairRDD rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")`,
+ *
+ * then `rdd` contains
+ * {{{
+ * (a-hdfs-path/part-00000, its content)
+ * (a-hdfs-path/part-00001, its content)
+ * ...
+ * (a-hdfs-path/part-nnnnn, its content)
+ * }}}
+ *
+ * @note Small files are preferred, as each file will be loaded fully in memory.
+ */
+ def wholeTextFiles(path: String): JavaPairRDD[String, String] =
+ new JavaPairRDD(sc.wholeTextFiles(path))
+
/** Get an RDD for a Hadoop SequenceFile with given key and value types.
*
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
@@ -463,7 +491,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
sc.setCheckpointDir(dir)
}
- def getCheckpointDir = JavaUtils.optionToOptional(sc.getCheckpointDir)
+ def getCheckpointDir: Optional[String] = JavaUtils.optionToOptional(sc.getCheckpointDir)
protected def checkpointFile[T](path: String): JavaRDD[T] = {
implicit val ctag: ClassTag[T] = fakeClassTag
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index b67286a4e3b75..32f1100406d74 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -19,6 +19,7 @@ package org.apache.spark.api.python
import java.io._
import java.net._
+import java.nio.charset.Charset
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
import scala.collection.JavaConversions._
@@ -206,6 +207,7 @@ private object SpecialLengths {
}
private[spark] object PythonRDD {
+ val UTF8 = Charset.forName("UTF-8")
def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
JavaRDD[Array[Byte]] = {
@@ -266,7 +268,7 @@ private[spark] object PythonRDD {
}
def writeUTF(str: String, dataOut: DataOutputStream) {
- val bytes = str.getBytes("UTF-8")
+ val bytes = str.getBytes(UTF8)
dataOut.writeInt(bytes.length)
dataOut.write(bytes)
}
@@ -286,7 +288,7 @@ private[spark] object PythonRDD {
private
class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
- override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
+ override def call(arr: Array[Byte]) : String = new String(arr, PythonRDD.UTF8)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 3cd71213769b7..2595c15104e87 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -167,7 +167,7 @@ extends Logging {
private var initialized = false
private var conf: SparkConf = null
def initialize(_isDriver: Boolean, conf: SparkConf) {
- TorrentBroadcast.conf = conf //TODO: we might have to fix it in tests
+ TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests
synchronized {
if (!initialized) {
initialized = true
diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala
index d9e3035e1ab59..8fd2c7e95b966 100644
--- a/core/src/main/scala/org/apache/spark/deploy/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -128,6 +128,9 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends
*/
object Client {
def main(args: Array[String]) {
+ println("WARNING: This client is deprecated and will be removed in a future version of Spark.")
+ println("Use ./bin/spark-submit with \"--master spark://host:port\"")
+
val conf = new SparkConf()
val driverArgs = new ClientArguments(args)
diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
index 00f5cd54ad650..c07838f798799 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
@@ -112,5 +112,5 @@ private[spark] class ClientArguments(args: Array[String]) {
}
object ClientArguments {
- def isValidJarUrl(s: String) = s.matches("(.+):(.+)jar")
+ def isValidJarUrl(s: String): Boolean = s.matches("(.+):(.+)jar")
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
index 83ce14a0a806a..a7368f9f3dfbe 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
@@ -86,6 +86,10 @@ private[deploy] object DeployMessages {
case class KillDriver(driverId: String) extends DeployMessage
+ // Worker internal
+
+ case object WorkDirCleanup // Sent to Worker actor periodically for cleaning up app folders
+
// AppClient to Master
case class RegisterApplication(appDescription: ApplicationDescription)
diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
index a73b459c3cea1..9a7a113c95715 100644
--- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
@@ -66,9 +66,9 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I
// TODO: In Akka 2.1.x, ActorSystem.awaitTermination hangs when you have remote actors!
// This is unfortunate, but for now we just comment it out.
workerActorSystems.foreach(_.shutdown())
- //workerActorSystems.foreach(_.awaitTermination())
+ // workerActorSystems.foreach(_.awaitTermination())
masterActorSystems.foreach(_.shutdown())
- //masterActorSystems.foreach(_.awaitTermination())
+ // masterActorSystems.foreach(_.awaitTermination())
masterActorSystems.clear()
workerActorSystems.clear()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index d2d8d6d662d55..9bdbfb33bf54f 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -32,7 +32,7 @@ import scala.collection.JavaConversions._
* Contains util methods to interact with Hadoop from Spark.
*/
class SparkHadoopUtil {
- val conf = newConfiguration()
+ val conf: Configuration = newConfiguration()
UserGroupInformation.setConfiguration(conf)
def runAsUser(user: String)(func: () => Unit) {
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
new file mode 100644
index 0000000000000..e05fbfe321495
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+import java.io.{PrintStream, File}
+import java.net.URL
+
+import org.apache.spark.executor.ExecutorURLClassLoader
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.Map
+
+/**
+ * Scala code behind the spark-submit script. The script handles setting up the classpath with
+ * relevant Spark dependencies and provides a layer over the different cluster managers and deploy
+ * modes that Spark supports.
+ */
+object SparkSubmit {
+ private val YARN = 1
+ private val STANDALONE = 2
+ private val MESOS = 4
+ private val LOCAL = 8
+ private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL
+
+ private var clusterManager: Int = LOCAL
+
+ def main(args: Array[String]) {
+ val appArgs = new SparkSubmitArguments(args)
+ if (appArgs.verbose) {
+ printStream.println(appArgs)
+ }
+ val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs)
+ launch(childArgs, classpath, sysProps, mainClass, appArgs.verbose)
+ }
+
+ // Exposed for testing
+ private[spark] var printStream: PrintStream = System.err
+ private[spark] var exitFn: () => Unit = () => System.exit(-1)
+
+ private[spark] def printErrorAndExit(str: String) = {
+ printStream.println("error: " + str)
+ printStream.println("run with --help for more information or --verbose for debugging output")
+ exitFn()
+ }
+ private[spark] def printWarning(str: String) = printStream.println("warning: " + str)
+
+ /**
+ * @return
+ * a tuple containing the arguments for the child, a list of classpath
+ * entries for the child, and the main class for the child
+ */
+ private[spark] def createLaunchEnv(appArgs: SparkSubmitArguments): (ArrayBuffer[String],
+ ArrayBuffer[String], Map[String, String], String) = {
+ if (appArgs.master.startsWith("local")) {
+ clusterManager = LOCAL
+ } else if (appArgs.master.startsWith("yarn")) {
+ clusterManager = YARN
+ } else if (appArgs.master.startsWith("spark")) {
+ clusterManager = STANDALONE
+ } else if (appArgs.master.startsWith("mesos")) {
+ clusterManager = MESOS
+ } else {
+ printErrorAndExit("master must start with yarn, mesos, spark, or local")
+ }
+
+ // Because "yarn-cluster" and "yarn-client" encapsulate both the master
+ // and deploy mode, we have some logic to infer the master and deploy mode
+ // from each other if only one is specified, or exit early if they are at odds.
+ if (appArgs.deployMode == null &&
+ (appArgs.master == "yarn-standalone" || appArgs.master == "yarn-cluster")) {
+ appArgs.deployMode = "cluster"
+ }
+ if (appArgs.deployMode == "cluster" && appArgs.master == "yarn-client") {
+ printErrorAndExit("Deploy mode \"cluster\" and master \"yarn-client\" are not compatible")
+ }
+ if (appArgs.deployMode == "client" &&
+ (appArgs.master == "yarn-standalone" || appArgs.master == "yarn-cluster")) {
+ printErrorAndExit("Deploy mode \"client\" and master \"" + appArgs.master
+ + "\" are not compatible")
+ }
+ if (appArgs.deployMode == "cluster" && appArgs.master.startsWith("yarn")) {
+ appArgs.master = "yarn-cluster"
+ }
+ if (appArgs.deployMode != "cluster" && appArgs.master.startsWith("yarn")) {
+ appArgs.master = "yarn-client"
+ }
+
+ val deployOnCluster = Option(appArgs.deployMode).getOrElse("client") == "cluster"
+
+ val childClasspath = new ArrayBuffer[String]()
+ val childArgs = new ArrayBuffer[String]()
+ val sysProps = new HashMap[String, String]()
+ var childMainClass = ""
+
+ if (clusterManager == MESOS && deployOnCluster) {
+ printErrorAndExit("Mesos does not support running the driver on the cluster")
+ }
+
+ if (!deployOnCluster) {
+ childMainClass = appArgs.mainClass
+ childClasspath += appArgs.primaryResource
+ } else if (clusterManager == YARN) {
+ childMainClass = "org.apache.spark.deploy.yarn.Client"
+ childArgs += ("--jar", appArgs.primaryResource)
+ childArgs += ("--class", appArgs.mainClass)
+ }
+
+ val options = List[OptionAssigner](
+ new OptionAssigner(appArgs.master, ALL_CLUSTER_MGRS, false, sysProp = "spark.master"),
+ new OptionAssigner(appArgs.driverMemory, YARN, true, clOption = "--driver-memory"),
+ new OptionAssigner(appArgs.name, YARN, true, clOption = "--name"),
+ new OptionAssigner(appArgs.queue, YARN, true, clOption = "--queue"),
+ new OptionAssigner(appArgs.queue, YARN, false, sysProp = "spark.yarn.queue"),
+ new OptionAssigner(appArgs.numExecutors, YARN, true, clOption = "--num-executors"),
+ new OptionAssigner(appArgs.numExecutors, YARN, false, sysProp = "spark.executor.instances"),
+ new OptionAssigner(appArgs.executorMemory, YARN, true, clOption = "--executor-memory"),
+ new OptionAssigner(appArgs.executorMemory, STANDALONE | MESOS | YARN, false,
+ sysProp = "spark.executor.memory"),
+ new OptionAssigner(appArgs.driverMemory, STANDALONE, true, clOption = "--memory"),
+ new OptionAssigner(appArgs.driverCores, STANDALONE, true, clOption = "--cores"),
+ new OptionAssigner(appArgs.executorCores, YARN, true, clOption = "--executor-cores"),
+ new OptionAssigner(appArgs.executorCores, YARN, false, sysProp = "spark.executor.cores"),
+ new OptionAssigner(appArgs.totalExecutorCores, STANDALONE | MESOS, false,
+ sysProp = "spark.cores.max"),
+ new OptionAssigner(appArgs.files, YARN, false, sysProp = "spark.yarn.dist.files"),
+ new OptionAssigner(appArgs.files, YARN, true, clOption = "--files"),
+ new OptionAssigner(appArgs.archives, YARN, false, sysProp = "spark.yarn.dist.archives"),
+ new OptionAssigner(appArgs.archives, YARN, true, clOption = "--archives"),
+ new OptionAssigner(appArgs.jars, YARN, true, clOption = "--addJars")
+ )
+
+ // more jars
+ if (appArgs.jars != null && !deployOnCluster) {
+ for (jar <- appArgs.jars.split(",")) {
+ childClasspath += jar
+ }
+ }
+
+ for (opt <- options) {
+ if (opt.value != null && deployOnCluster == opt.deployOnCluster &&
+ (clusterManager & opt.clusterManager) != 0) {
+ if (opt.clOption != null) {
+ childArgs += (opt.clOption, opt.value)
+ } else if (opt.sysProp != null) {
+ sysProps.put(opt.sysProp, opt.value)
+ }
+ }
+ }
+
+ if (deployOnCluster && clusterManager == STANDALONE) {
+ if (appArgs.supervise) {
+ childArgs += "--supervise"
+ }
+
+ childMainClass = "org.apache.spark.deploy.Client"
+ childArgs += "launch"
+ childArgs += (appArgs.master, appArgs.primaryResource, appArgs.mainClass)
+ }
+
+ // args
+ if (appArgs.childArgs != null) {
+ if (!deployOnCluster || clusterManager == STANDALONE) {
+ childArgs ++= appArgs.childArgs
+ } else if (clusterManager == YARN) {
+ for (arg <- appArgs.childArgs) {
+ childArgs += ("--args", arg)
+ }
+ }
+ }
+
+ (childArgs, childClasspath, sysProps, childMainClass)
+ }
+
+ private def launch(childArgs: ArrayBuffer[String], childClasspath: ArrayBuffer[String],
+ sysProps: Map[String, String], childMainClass: String, verbose: Boolean = false) {
+
+ if (verbose) {
+ System.err.println(s"Main class:\n$childMainClass")
+ System.err.println(s"Arguments:\n${childArgs.mkString("\n")}")
+ System.err.println(s"System properties:\n${sysProps.mkString("\n")}")
+ System.err.println(s"Classpath elements:\n${childClasspath.mkString("\n")}")
+ System.err.println("\n")
+ }
+
+ val loader = new ExecutorURLClassLoader(new Array[URL](0),
+ Thread.currentThread.getContextClassLoader)
+ Thread.currentThread.setContextClassLoader(loader)
+
+ for (jar <- childClasspath) {
+ addJarToClasspath(jar, loader)
+ }
+
+ for ((key, value) <- sysProps) {
+ System.setProperty(key, value)
+ }
+
+ val mainClass = Class.forName(childMainClass, true, loader)
+ val mainMethod = mainClass.getMethod("main", new Array[String](0).getClass)
+ mainMethod.invoke(null, childArgs.toArray)
+ }
+
+ private def addJarToClasspath(localJar: String, loader: ExecutorURLClassLoader) {
+ val localJarFile = new File(localJar)
+ if (!localJarFile.exists()) {
+ printWarning(s"Jar $localJar does not exist, skipping.")
+ }
+
+ val url = localJarFile.getAbsoluteFile.toURI.toURL
+ loader.addURL(url)
+ }
+}
+
+private[spark] class OptionAssigner(val value: String,
+ val clusterManager: Int,
+ val deployOnCluster: Boolean,
+ val clOption: String = null,
+ val sysProp: String = null
+) { }
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
new file mode 100644
index 0000000000000..834b3df2f164b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * Parses and encapsulates arguments from the spark-submit script.
+ */
+private[spark] class SparkSubmitArguments(args: Array[String]) {
+ var master: String = "local"
+ var deployMode: String = null
+ var executorMemory: String = null
+ var executorCores: String = null
+ var totalExecutorCores: String = null
+ var driverMemory: String = null
+ var driverCores: String = null
+ var supervise: Boolean = false
+ var queue: String = null
+ var numExecutors: String = null
+ var files: String = null
+ var archives: String = null
+ var mainClass: String = null
+ var primaryResource: String = null
+ var name: String = null
+ var childArgs: ArrayBuffer[String] = new ArrayBuffer[String]()
+ var jars: String = null
+ var verbose: Boolean = false
+
+ loadEnvVars()
+ parseOpts(args.toList)
+
+ // Sanity checks
+ if (args.length == 0) printUsageAndExit(-1)
+ if (primaryResource == null) SparkSubmit.printErrorAndExit("Must specify a primary resource")
+ if (mainClass == null) SparkSubmit.printErrorAndExit("Must specify a main class with --class")
+
+ override def toString = {
+ s"""Parsed arguments:
+ | master $master
+ | deployMode $deployMode
+ | executorMemory $executorMemory
+ | executorCores $executorCores
+ | totalExecutorCores $totalExecutorCores
+ | driverMemory $driverMemory
+ | drivercores $driverCores
+ | supervise $supervise
+ | queue $queue
+ | numExecutors $numExecutors
+ | files $files
+ | archives $archives
+ | mainClass $mainClass
+ | primaryResource $primaryResource
+ | name $name
+ | childArgs [${childArgs.mkString(" ")}]
+ | jars $jars
+ | verbose $verbose
+ """.stripMargin
+ }
+
+ private def loadEnvVars() {
+ Option(System.getenv("MASTER")).map(master = _)
+ Option(System.getenv("DEPLOY_MODE")).map(deployMode = _)
+ }
+
+ private def parseOpts(opts: List[String]): Unit = opts match {
+ case ("--name") :: value :: tail =>
+ name = value
+ parseOpts(tail)
+
+ case ("--master") :: value :: tail =>
+ master = value
+ parseOpts(tail)
+
+ case ("--class") :: value :: tail =>
+ mainClass = value
+ parseOpts(tail)
+
+ case ("--deploy-mode") :: value :: tail =>
+ if (value != "client" && value != "cluster") {
+ SparkSubmit.printErrorAndExit("--deploy-mode must be either \"client\" or \"cluster\"")
+ }
+ deployMode = value
+ parseOpts(tail)
+
+ case ("--num-executors") :: value :: tail =>
+ numExecutors = value
+ parseOpts(tail)
+
+ case ("--total-executor-cores") :: value :: tail =>
+ totalExecutorCores = value
+ parseOpts(tail)
+
+ case ("--executor-cores") :: value :: tail =>
+ executorCores = value
+ parseOpts(tail)
+
+ case ("--executor-memory") :: value :: tail =>
+ executorMemory = value
+ parseOpts(tail)
+
+ case ("--driver-memory") :: value :: tail =>
+ driverMemory = value
+ parseOpts(tail)
+
+ case ("--driver-cores") :: value :: tail =>
+ driverCores = value
+ parseOpts(tail)
+
+ case ("--supervise") :: tail =>
+ supervise = true
+ parseOpts(tail)
+
+ case ("--queue") :: value :: tail =>
+ queue = value
+ parseOpts(tail)
+
+ case ("--files") :: value :: tail =>
+ files = value
+ parseOpts(tail)
+
+ case ("--archives") :: value :: tail =>
+ archives = value
+ parseOpts(tail)
+
+ case ("--arg") :: value :: tail =>
+ childArgs += value
+ parseOpts(tail)
+
+ case ("--jars") :: value :: tail =>
+ jars = value
+ parseOpts(tail)
+
+ case ("--help" | "-h") :: tail =>
+ printUsageAndExit(0)
+
+ case ("--verbose" | "-v") :: tail =>
+ verbose = true
+ parseOpts(tail)
+
+ case value :: tail =>
+ if (primaryResource != null) {
+ val error = s"Found two conflicting resources, $value and $primaryResource." +
+ " Expecting only one resource."
+ SparkSubmit.printErrorAndExit(error)
+ }
+ primaryResource = value
+ parseOpts(tail)
+
+ case Nil =>
+ }
+
+ private def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
+ val outStream = SparkSubmit.printStream
+ if (unknownParam != null) {
+ outStream.println("Unknown/unsupported param " + unknownParam)
+ }
+ outStream.println(
+ """Usage: spark-submit [options]
+ |Options:
+ | --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local.
+ | --deploy-mode DEPLOY_MODE Mode to deploy the app in, either 'client' or 'cluster'.
+ | --class CLASS_NAME Name of your app's main class (required for Java apps).
+ | --arg ARG Argument to be passed to your application's main class. This
+ | option can be specified multiple times for multiple args.
+ | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: 512M).
+ | --name NAME The name of your application (Default: 'Spark').
+ | --jars JARS A comma-separated list of local jars to include on the
+ | driver classpath and that SparkContext.addJar will work
+ | with. Doesn't work on standalone with 'cluster' deploy mode.
+ |
+ | Spark standalone with cluster deploy mode only:
+ | --driver-cores NUM Cores for driver (Default: 1).
+ | --supervise If given, restarts the driver on failure.
+ |
+ | Spark standalone and Mesos only:
+ | --total-executor-cores NUM Total cores for all executors.
+ |
+ | YARN-only:
+ | --executor-cores NUM Number of cores per executor (Default: 1).
+ | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G).
+ | --queue QUEUE_NAME The YARN queue to submit to (Default: 'default').
+ | --num-executors NUM Number of executors to (Default: 2).
+ | --files FILES Comma separated list of files to be placed in the working dir
+ | of each executor.
+ | --archives ARCHIVES Comma separated list of archives to be extracted into the
+ | working dir of each executor.""".stripMargin
+ )
+ SparkSubmit.exitFn()
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
index a730fe1f599af..4433a2ec29be6 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
@@ -30,7 +30,7 @@ import org.apache.spark.deploy.master.MasterMessages.ElectedLeader
* [[org.apache.spark.deploy.master.MasterMessages.RevokedLeadership RevokedLeadership]]
*/
private[spark] trait LeaderElectionAgent extends Actor {
- //TODO: LeaderElectionAgent does not necessary to be an Actor anymore, need refactoring.
+ // TODO: LeaderElectionAgent does not necessary to be an Actor anymore, need refactoring.
val masterActor: ActorRef
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
index 5413ff671ad8d..834dfedee52ce 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -20,6 +20,7 @@ package org.apache.spark.deploy.master
import scala.collection.JavaConversions._
import akka.serialization.Serialization
+import org.apache.curator.framework.CuratorFramework
import org.apache.zookeeper.CreateMode
import org.apache.spark.{Logging, SparkConf}
@@ -29,7 +30,7 @@ class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf)
with Logging
{
val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/master_status"
- val zk = SparkCuratorUtil.newClient(conf)
+ val zk: CuratorFramework = SparkCuratorUtil.newClient(conf)
SparkCuratorUtil.mkdir(zk, WORKING_DIR)
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 8a71ddda4cb5e..bf5a8d09dd2df 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -64,6 +64,12 @@ private[spark] class Worker(
val REGISTRATION_TIMEOUT = 20.seconds
val REGISTRATION_RETRIES = 3
+ val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", true)
+ // How often worker will clean up old app folders
+ val CLEANUP_INTERVAL_MILLIS = conf.getLong("spark.worker.cleanup.interval", 60 * 30) * 1000
+ // TTL for app folders/data; after TTL expires it will be cleaned up
+ val APP_DATA_RETENTION_SECS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600)
+
// Index into masterUrls that we're currently trying to register with.
var masterIndex = 0
@@ -179,12 +185,28 @@ private[spark] class Worker(
registered = true
changeMaster(masterUrl, masterWebUiUrl)
context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis, self, SendHeartbeat)
+ if (CLEANUP_ENABLED) {
+ context.system.scheduler.schedule(CLEANUP_INTERVAL_MILLIS millis,
+ CLEANUP_INTERVAL_MILLIS millis, self, WorkDirCleanup)
+ }
case SendHeartbeat =>
masterLock.synchronized {
if (connected) { master ! Heartbeat(workerId) }
}
+ case WorkDirCleanup =>
+ // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker actor
+ val cleanupFuture = concurrent.future {
+ logInfo("Cleaning up oldest application directories in " + workDir + " ...")
+ Utils.findOldFiles(workDir, APP_DATA_RETENTION_SECS)
+ .foreach(Utils.deleteRecursively)
+ }
+ cleanupFuture onFailure {
+ case e: Throwable =>
+ logError("App dir cleanup failed: " + e.getMessage, e)
+ }
+
case MasterChanged(masterUrl, masterWebUiUrl) =>
logInfo("Master has changed, new master is at " + masterUrl)
changeMaster(masterUrl, masterWebUiUrl)
@@ -331,7 +353,6 @@ private[spark] class Worker(
}
private[spark] object Worker {
-
def main(argStrings: Array[String]) {
val args = new WorkerArguments(argStrings)
val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores,
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 3486092a140fb..16887d8892b31 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -53,7 +53,8 @@ private[spark] class CoarseGrainedExecutorBackend(
case RegisteredExecutor(sparkProperties) =>
logInfo("Successfully registered with driver")
// Make this host instead of hostPort ?
- executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties)
+ executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties,
+ false)
case RegisterExecutorFailed(message) =>
logError("Slave registration failed: " + message)
@@ -105,7 +106,8 @@ private[spark] object CoarseGrainedExecutorBackend {
// set it
val sparkHostPort = hostname + ":" + boundPort
actorSystem.actorOf(
- Props(classOf[CoarseGrainedExecutorBackend], driverUrl, executorId, sparkHostPort, cores),
+ Props(classOf[CoarseGrainedExecutorBackend], driverUrl, executorId,
+ sparkHostPort, cores),
name = "Executor")
workerUrl.foreach{ url =>
actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher")
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 8fe9b848ba145..aecb069e4202b 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -112,11 +112,10 @@ private[spark] class Executor(
}
}
- // Create our ClassLoader and set it on this thread
+ // Create our ClassLoader
// do this after SparkEnv creation so can access the SecurityManager
private val urlClassLoader = createClassLoader()
private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
- Thread.currentThread.setContextClassLoader(replClassLoader)
// Akka's message frame size. If task result is bigger than this, we use the block manager
// to send the result back.
@@ -276,7 +275,6 @@ private[spark] class Executor(
// have left some weird state around depending on when the exception was thrown, but on
// the other hand, maybe we could detect that when future tasks fail and exit then.
logError("Exception in task ID " + taskId, t)
- //System.exit(1)
}
} finally {
// TODO: Unregister shuffle memory only for ResultTask
@@ -294,7 +292,7 @@ private[spark] class Executor(
* created by the interpreter to the search path
*/
private def createClassLoader(): ExecutorURLClassLoader = {
- val loader = this.getClass.getClassLoader
+ val loader = Thread.currentThread().getContextClassLoader
// For each of the jars in the jarSet, add them to the class loader.
// We assume each of the files has already been fetched.
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala
index 210f3dbeebaca..ceff3a067d72a 100644
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala
@@ -41,6 +41,12 @@ object ExecutorExitCode {
/** DiskStore failed to create a local temporary directory after many attempts. */
val DISK_STORE_FAILED_TO_CREATE_DIR = 53
+ /** TachyonStore failed to initialize after many attempts. */
+ val TACHYON_STORE_FAILED_TO_INITIALIZE = 54
+
+ /** TachyonStore failed to create a local temporary directory after many attempts. */
+ val TACHYON_STORE_FAILED_TO_CREATE_DIR = 55
+
def explainExitCode(exitCode: Int): String = {
exitCode match {
case UNCAUGHT_EXCEPTION => "Uncaught exception"
@@ -48,6 +54,9 @@ object ExecutorExitCode {
case OOM => "OutOfMemoryError"
case DISK_STORE_FAILED_TO_CREATE_DIR =>
"Failed to create local directory (bad spark.local.dir?)"
+ case TACHYON_STORE_FAILED_TO_INITIALIZE => "TachyonStore failed to initialize."
+ case TACHYON_STORE_FAILED_TO_CREATE_DIR =>
+ "TachyonStore failed to create a local temporary directory."
case _ =>
"Unknown executor exit code (" + exitCode + ")" + (
if (exitCode > 128) {
diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
new file mode 100644
index 0000000000000..4887fb6b84eb2
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.input
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce.InputSplit
+import org.apache.hadoop.mapreduce.JobContext
+import org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat
+import org.apache.hadoop.mapreduce.RecordReader
+import org.apache.hadoop.mapreduce.TaskAttemptContext
+import org.apache.hadoop.mapreduce.lib.input.CombineFileRecordReader
+import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit
+
+/**
+ * A [[org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat CombineFileInputFormat]] for
+ * reading whole text files. Each file is read as key-value pair, where the key is the file path and
+ * the value is the entire content of file.
+ */
+
+private[spark] class WholeTextFileInputFormat extends CombineFileInputFormat[String, String] {
+ override protected def isSplitable(context: JobContext, file: Path): Boolean = false
+
+ override def createRecordReader(
+ split: InputSplit,
+ context: TaskAttemptContext): RecordReader[String, String] = {
+
+ new CombineFileRecordReader[String, String](
+ split.asInstanceOf[CombineFileSplit],
+ context,
+ classOf[WholeTextFileRecordReader])
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
new file mode 100644
index 0000000000000..c3dabd2e79995
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.input
+
+import com.google.common.io.{ByteStreams, Closeables}
+
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.mapreduce.InputSplit
+import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit
+import org.apache.hadoop.mapreduce.RecordReader
+import org.apache.hadoop.mapreduce.TaskAttemptContext
+
+/**
+ * A [[org.apache.hadoop.mapreduce.RecordReader RecordReader]] for reading a single whole text file
+ * out in a key-value pair, where the key is the file path and the value is the entire content of
+ * the file.
+ */
+private[spark] class WholeTextFileRecordReader(
+ split: CombineFileSplit,
+ context: TaskAttemptContext,
+ index: Integer)
+ extends RecordReader[String, String] {
+
+ private val path = split.getPath(index)
+ private val fs = path.getFileSystem(context.getConfiguration)
+
+ // True means the current file has been processed, then skip it.
+ private var processed = false
+
+ private val key = path.toString
+ private var value: String = null
+
+ override def initialize(split: InputSplit, context: TaskAttemptContext) = {}
+
+ override def close() = {}
+
+ override def getProgress = if (processed) 1.0f else 0.0f
+
+ override def getCurrentKey = key
+
+ override def getCurrentValue = value
+
+ override def nextKeyValue = {
+ if (!processed) {
+ val fileIn = fs.open(path)
+ val innerBuffer = ByteStreams.toByteArray(fileIn)
+
+ value = new Text(innerBuffer).toString
+ Closeables.close(fileIn, false)
+
+ processed = true
+ true
+ } else {
+ false
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala
index 6883a54494598..3e3e18c3537d0 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala
@@ -42,7 +42,7 @@ private[spark] class MetricsConfig(val configFile: Option[String]) extends Loggi
}
def initialize() {
- //Add default properties in case there's no properties file
+ // Add default properties in case there's no properties file
setDefaultProperties(properties)
// If spark.metrics.conf is not set, try to get file in class path
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
index 4d2ffc54d8983..64eac73605388 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
@@ -38,7 +38,7 @@ class ConsoleSink(val property: Properties, val registry: MetricRegistry,
case None => CONSOLE_DEFAULT_PERIOD
}
- val pollUnit = Option(property.getProperty(CONSOLE_KEY_UNIT)) match {
+ val pollUnit: TimeUnit = Option(property.getProperty(CONSOLE_KEY_UNIT)) match {
case Some(s) => TimeUnit.valueOf(s.toUpperCase())
case None => TimeUnit.valueOf(CONSOLE_DEFAULT_UNIT)
}
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
index 319f40815d65f..544848d4150b6 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
@@ -41,7 +41,7 @@ class CsvSink(val property: Properties, val registry: MetricRegistry,
case None => CSV_DEFAULT_PERIOD
}
- val pollUnit = Option(property.getProperty(CSV_KEY_UNIT)) match {
+ val pollUnit: TimeUnit = Option(property.getProperty(CSV_KEY_UNIT)) match {
case Some(s) => TimeUnit.valueOf(s.toUpperCase())
case None => TimeUnit.valueOf(CSV_DEFAULT_UNIT)
}
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
index 0ffdf3846dc4a..7f0a2fd16fa99 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
@@ -39,7 +39,7 @@ class GraphiteSink(val property: Properties, val registry: MetricRegistry,
val GRAPHITE_KEY_UNIT = "unit"
val GRAPHITE_KEY_PREFIX = "prefix"
- def propertyToOption(prop: String) = Option(property.getProperty(prop))
+ def propertyToOption(prop: String): Option[String] = Option(property.getProperty(prop))
if (!propertyToOption(GRAPHITE_KEY_HOST).isDefined) {
throw new Exception("Graphite sink requires 'host' property.")
@@ -57,7 +57,7 @@ class GraphiteSink(val property: Properties, val registry: MetricRegistry,
case None => GRAPHITE_DEFAULT_PERIOD
}
- val pollUnit = propertyToOption(GRAPHITE_KEY_UNIT) match {
+ val pollUnit: TimeUnit = propertyToOption(GRAPHITE_KEY_UNIT) match {
case Some(s) => TimeUnit.valueOf(s.toUpperCase())
case None => TimeUnit.valueOf(GRAPHITE_DEFAULT_UNIT)
}
diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala
index 8fd9c2b87d256..2f7576c53b482 100644
--- a/core/src/main/scala/org/apache/spark/network/Connection.scala
+++ b/core/src/main/scala/org/apache/spark/network/Connection.scala
@@ -48,7 +48,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
channel.socket.setTcpNoDelay(true)
channel.socket.setReuseAddress(true)
channel.socket.setKeepAlive(true)
- /*channel.socket.setReceiveBufferSize(32768) */
+ /* channel.socket.setReceiveBufferSize(32768) */
@volatile private var closed = false
var onCloseCallback: Connection => Unit = null
@@ -206,12 +206,12 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
private class Outbox {
val messages = new Queue[Message]()
- val defaultChunkSize = 65536 //32768 //16384
+ val defaultChunkSize = 65536
var nextMessageToBeUsed = 0
def addMessage(message: Message) {
messages.synchronized{
- /*messages += message*/
+ /* messages += message */
messages.enqueue(message)
logDebug("Added [" + message + "] to outbox for sending to " +
"[" + getRemoteConnectionManagerId() + "]")
@@ -221,8 +221,8 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
def getChunk(): Option[MessageChunk] = {
messages.synchronized {
while (!messages.isEmpty) {
- /*nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
- /*val message = messages(nextMessageToBeUsed)*/
+ /* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
+ /* val message = messages(nextMessageToBeUsed) */
val message = messages.dequeue
val chunk = message.getChunkForSending(defaultChunkSize)
if (chunk.isDefined) {
@@ -262,7 +262,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
val currentBuffers = new ArrayBuffer[ByteBuffer]()
- /*channel.socket.setSendBufferSize(256 * 1024)*/
+ /* channel.socket.setSendBufferSize(256 * 1024) */
override def getRemoteAddress() = address
@@ -355,7 +355,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
}
case None => {
// changeConnectionKeyInterest(0)
- /*key.interestOps(0)*/
+ /* key.interestOps(0) */
return false
}
}
@@ -540,10 +540,10 @@ private[spark] class ReceivingConnection(
return false
}
- /*logDebug("Read " + bytesRead + " bytes for the buffer")*/
+ /* logDebug("Read " + bytesRead + " bytes for the buffer") */
if (currentChunk.buffer.remaining == 0) {
- /*println("Filled buffer at " + System.currentTimeMillis)*/
+ /* println("Filled buffer at " + System.currentTimeMillis) */
val bufferMessage = inbox.getMessageForChunk(currentChunk).get
if (bufferMessage.isCompletelyReceived) {
bufferMessage.flip
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
index a75130cba2a2e..6b0a972f0bbe0 100644
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
@@ -505,7 +505,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
}
}
handleMessageExecutor.execute(runnable)
- /*handleMessage(connection, message)*/
+ /* handleMessage(connection, message) */
}
private def handleClientAuthentication(
@@ -733,7 +733,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
logTrace("Sending Security [" + message + "] to [" + connManagerId + "]")
val connection = connectionsById.getOrElseUpdate(connManagerId, startNewConnection())
- //send security message until going connection has been authenticated
+ // send security message until going connection has been authenticated
connection.send(message)
wakeupSelector()
@@ -859,14 +859,14 @@ private[spark] object ConnectionManager {
None
})
- /*testSequentialSending(manager)*/
- /*System.gc()*/
+ /* testSequentialSending(manager) */
+ /* System.gc() */
- /*testParallelSending(manager)*/
- /*System.gc()*/
+ /* testParallelSending(manager) */
+ /* System.gc() */
- /*testParallelDecreasingSending(manager)*/
- /*System.gc()*/
+ /* testParallelDecreasingSending(manager) */
+ /* System.gc() */
testContinuousSending(manager)
System.gc()
@@ -948,7 +948,7 @@ private[spark] object ConnectionManager {
val ms = finishTime - startTime
val tput = mb * 1000.0 / ms
println("--------------------------")
- /*println("Started at " + startTime + ", finished at " + finishTime) */
+ /* println("Started at " + startTime + ", finished at " + finishTime) */
println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
println("--------------------------")
println()
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
index 35f64134b073a..9d9b9dbdd5331 100644
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
@@ -47,8 +47,8 @@ private[spark] object ConnectionManagerTest extends Logging{
val slaves = slavesFile.mkString.split("\n")
slavesFile.close()
- /*println("Slaves")*/
- /*slaves.foreach(println)*/
+ /* println("Slaves") */
+ /* slaves.foreach(println) */
val tasknum = if (args.length > 2) args(2).toInt else slaves.length
val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024
val count = if (args.length > 4) args(4).toInt else 3
diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
index 3c09a713c6fe0..2b41c403b2e0a 100644
--- a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
+++ b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
@@ -27,7 +27,7 @@ private[spark] object ReceiverTest {
println("Started connection manager with id = " + manager.id)
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
- /*println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis)*/
+ /* println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis) */
val buffer = ByteBuffer.wrap("response".getBytes)
Some(Message.createBufferMessage(buffer, msg.id))
})
diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala
index aac2c24a46faa..14c094c6177d5 100644
--- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala
+++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala
@@ -50,7 +50,7 @@ private[spark] object SenderTest {
(0 until count).foreach(i => {
val dataMessage = Message.createBufferMessage(buffer.duplicate)
val startTime = System.currentTimeMillis
- /*println("Started timer at " + startTime)*/
+ /* println("Started timer at " + startTime) */
val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage)
.map { response =>
val buffer = response.asInstanceOf[BufferMessage].buffers(0)
diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
index f9082ffb9141a..4164e81d3a8ae 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
@@ -32,7 +32,7 @@ private[spark] class FileHeader (
buf.writeInt(fileLen)
buf.writeInt(blockId.name.length)
blockId.name.foreach((x: Char) => buf.writeByte(x))
- //padding the rest of header
+ // padding the rest of header
if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) {
buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes)
} else {
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 8561711931047..9aa454a5c8b88 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -103,7 +103,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
array
}
- override val partitioner = Some(part)
+ override val partitioner: Some[Partitioner] = Some(part)
override def compute(s: Partition, context: TaskContext): Iterator[(K, CoGroupCombiner)] = {
val sparkConf = SparkEnv.get.conf
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 932ff5bf369c7..3af008bd72378 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -171,7 +171,7 @@ class HadoopRDD[K, V](
array
}
- override def compute(theSplit: Partition, context: TaskContext) = {
+ override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
val iter = new NextIterator[(K, V)] {
val split = theSplit.asInstanceOf[HadoopPartition]
diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
index 8df8718f3b65b..1b503743ac117 100644
--- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -116,7 +116,7 @@ class JdbcRDD[T: ClassTag](
}
object JdbcRDD {
- def resultSetToObjectArray(rs: ResultSet) = {
+ def resultSetToObjectArray(rs: ResultSet): Array[Object] = {
Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index d1fff296878c3..461a749eac48b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -80,7 +80,7 @@ class NewHadoopRDD[K, V](
result
}
- override def compute(theSplit: Partition, context: TaskContext) = {
+ override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
val iter = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[NewHadoopPartition]
logInfo("Input split: " + split.serializableHadoopSplit)
diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index 4250a9d02f764..41ae0fec823e7 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -17,6 +17,9 @@
package org.apache.spark.rdd
+import java.io.File
+import java.io.FilenameFilter
+import java.io.IOException
import java.io.PrintWriter
import java.util.StringTokenizer
@@ -27,6 +30,7 @@ import scala.io.Source
import scala.reflect.ClassTag
import org.apache.spark.{Partition, SparkEnv, TaskContext}
+import org.apache.spark.util.Utils
/**
@@ -38,7 +42,8 @@ class PipedRDD[T: ClassTag](
command: Seq[String],
envVars: Map[String, String],
printPipeContext: (String => Unit) => Unit,
- printRDDElement: (T, String => Unit) => Unit)
+ printRDDElement: (T, String => Unit) => Unit,
+ separateWorkingDir: Boolean)
extends RDD[String](prev) {
// Similar to Runtime.exec(), if we are given a single string, split it into words
@@ -48,12 +53,24 @@ class PipedRDD[T: ClassTag](
command: String,
envVars: Map[String, String] = Map(),
printPipeContext: (String => Unit) => Unit = null,
- printRDDElement: (T, String => Unit) => Unit = null) =
- this(prev, PipedRDD.tokenize(command), envVars, printPipeContext, printRDDElement)
+ printRDDElement: (T, String => Unit) => Unit = null,
+ separateWorkingDir: Boolean = false) =
+ this(prev, PipedRDD.tokenize(command), envVars, printPipeContext, printRDDElement,
+ separateWorkingDir)
override def getPartitions: Array[Partition] = firstParent[T].partitions
+ /**
+ * A FilenameFilter that accepts anything that isn't equal to the name passed in.
+ * @param name of file or directory to leave out
+ */
+ class NotEqualsFileNameFilter(filterName: String) extends FilenameFilter {
+ def accept(dir: File, name: String): Boolean = {
+ !name.equals(filterName)
+ }
+ }
+
override def compute(split: Partition, context: TaskContext): Iterator[String] = {
val pb = new ProcessBuilder(command)
// Add the environmental variables to the process.
@@ -67,6 +84,38 @@ class PipedRDD[T: ClassTag](
currentEnvVars.putAll(hadoopSplit.getPipeEnvVars())
}
+ // When spark.worker.separated.working.directory option is turned on, each
+ // task will be run in separate directory. This should be resolve file
+ // access conflict issue
+ val taskDirectory = "./tasks/" + java.util.UUID.randomUUID.toString
+ var workInTaskDirectory = false
+ logDebug("taskDirectory = " + taskDirectory)
+ if (separateWorkingDir == true) {
+ val currentDir = new File(".")
+ logDebug("currentDir = " + currentDir.getAbsolutePath())
+ val taskDirFile = new File(taskDirectory)
+ taskDirFile.mkdirs()
+
+ try {
+ val tasksDirFilter = new NotEqualsFileNameFilter("tasks")
+
+ // Need to add symlinks to jars, files, and directories. On Yarn we could have
+ // directories and other files not known to the SparkContext that were added via the
+ // Hadoop distributed cache. We also don't want to symlink to the /tasks directories we
+ // are creating here.
+ for (file <- currentDir.list(tasksDirFilter)) {
+ val fileWithDir = new File(currentDir, file)
+ Utils.symlink(new File(fileWithDir.getAbsolutePath()),
+ new File(taskDirectory + "/" + fileWithDir.getName()))
+ }
+ pb.directory(taskDirFile)
+ workInTaskDirectory = true
+ } catch {
+ case e: Exception => logError("Unable to setup task working directory: " + e.getMessage +
+ " (" + taskDirectory + ")")
+ }
+ }
+
val proc = pb.start()
val env = SparkEnv.get
@@ -112,6 +161,15 @@ class PipedRDD[T: ClassTag](
if (exitStatus != 0) {
throw new Exception("Subprocess exited with status " + exitStatus)
}
+
+ // cleanup task working directory if used
+ if (workInTaskDirectory == true) {
+ scala.util.control.Exception.ignoring(classOf[IOException]) {
+ Utils.deleteRecursively(new File(taskDirectory))
+ }
+ logDebug("Removed task working directory " + taskDirectory)
+ }
+
false
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 6af42248a5c3c..c43823bd769b7 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -121,7 +121,7 @@ abstract class RDD[T: ClassTag](
@transient var name: String = null
/** Assign a name to this RDD */
- def setName(_name: String) = {
+ def setName(_name: String): RDD[T] = {
name = _name
this
}
@@ -481,16 +481,19 @@ abstract class RDD[T: ClassTag](
* instead of constructing a huge String to concat all the elements:
* def printRDDElement(record:(String, Seq[String]), f:String=>Unit) =
* for (e <- record._2){f(e)}
+ * @param separateWorkingDir Use separate working directories for each task.
* @return the result RDD
*/
def pipe(
command: Seq[String],
env: Map[String, String] = Map(),
printPipeContext: (String => Unit) => Unit = null,
- printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = {
+ printRDDElement: (T, String => Unit) => Unit = null,
+ separateWorkingDir: Boolean = false): RDD[String] = {
new PipedRDD(this, command, env,
if (printPipeContext ne null) sc.clean(printPipeContext) else null,
- if (printRDDElement ne null) sc.clean(printRDDElement) else null)
+ if (printRDDElement ne null) sc.clean(printRDDElement) else null,
+ separateWorkingDir)
}
/**
@@ -658,6 +661,18 @@ abstract class RDD[T: ClassTag](
Array.concat(results: _*)
}
+ /**
+ * Return an iterator that contains all of the elements in this RDD.
+ *
+ * The iterator will consume as much memory as the largest partition in this RDD.
+ */
+ def toLocalIterator: Iterator[T] = {
+ def collectPartition(p: Int): Array[T] = {
+ sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p), allowLocal = false).head
+ }
+ (0 until partitions.length).iterator.flatMap(i => collectPartition(i))
+ }
+
/**
* Return an array that contains all of the elements in this RDD.
*/
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 77c558ac46f6f..442a95bb2c44b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -84,9 +84,9 @@ class DAGScheduler(
private[scheduler] val stageIdToJobIds = new TimeStampedHashMap[Int, HashSet[Int]]
private[scheduler] val stageIdToStage = new TimeStampedHashMap[Int, Stage]
private[scheduler] val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
- private[scheduler] val stageIdToActiveJob = new HashMap[Int, ActiveJob]
+ private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob]
private[scheduler] val resultStageToJob = new HashMap[Stage, ActiveJob]
- private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]
+ private[scheduler] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]
// Stages we need to run whose parents aren't done
private[scheduler] val waitingStages = new HashSet[Stage]
@@ -536,7 +536,7 @@ class DAGScheduler(
listenerBus.post(SparkListenerJobStart(job.jobId, Array[Int](), properties))
runLocally(job)
} else {
- stageIdToActiveJob(jobId) = job
+ jobIdToActiveJob(jobId) = job
activeJobs += job
resultStageToJob(finalStage) = job
listenerBus.post(
@@ -559,7 +559,7 @@ class DAGScheduler(
// Cancel all running jobs.
runningStages.map(_.jobId).foreach(handleJobCancellation)
activeJobs.clear() // These should already be empty by this point,
- stageIdToActiveJob.clear() // but just in case we lost track of some jobs...
+ jobIdToActiveJob.clear() // but just in case we lost track of some jobs...
case ExecutorAdded(execId, host) =>
handleExecutorAdded(execId, host)
@@ -569,7 +569,6 @@ class DAGScheduler(
case BeginEvent(task, taskInfo) =>
for (
- job <- stageIdToActiveJob.get(task.stageId);
stage <- stageIdToStage.get(task.stageId);
stageInfo <- stageToInfos.get(stage)
) {
@@ -697,7 +696,7 @@ class DAGScheduler(
private def activeJobForStage(stage: Stage): Option[Int] = {
if (stageIdToJobIds.contains(stage.id)) {
val jobsThatUseStage: Array[Int] = stageIdToJobIds(stage.id).toArray.sorted
- jobsThatUseStage.find(stageIdToActiveJob.contains)
+ jobsThatUseStage.find(jobIdToActiveJob.contains)
} else {
None
}
@@ -750,10 +749,10 @@ class DAGScheduler(
}
}
- val properties = if (stageIdToActiveJob.contains(jobId)) {
- stageIdToActiveJob(stage.jobId).properties
+ val properties = if (jobIdToActiveJob.contains(jobId)) {
+ jobIdToActiveJob(stage.jobId).properties
} else {
- //this stage will be assigned to "default" pool
+ // this stage will be assigned to "default" pool
null
}
@@ -827,7 +826,7 @@ class DAGScheduler(
job.numFinished += 1
// If the whole job has finished, remove it
if (job.numFinished == job.numPartitions) {
- stageIdToActiveJob -= stage.jobId
+ jobIdToActiveJob -= stage.jobId
activeJobs -= job
resultStageToJob -= stage
markStageAsFinished(stage)
@@ -986,11 +985,11 @@ class DAGScheduler(
val independentStages = removeJobAndIndependentStages(jobId)
independentStages.foreach(taskScheduler.cancelTasks)
val error = new SparkException("Job %d cancelled".format(jobId))
- val job = stageIdToActiveJob(jobId)
+ val job = jobIdToActiveJob(jobId)
job.listener.jobFailed(error)
jobIdToStageIds -= jobId
activeJobs -= job
- stageIdToActiveJob -= jobId
+ jobIdToActiveJob -= jobId
listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error, job.finalStage.id)))
}
}
@@ -1011,7 +1010,7 @@ class DAGScheduler(
val error = new SparkException("Job aborted: " + reason)
job.listener.jobFailed(error)
jobIdToStageIdsRemove(job.jobId)
- stageIdToActiveJob -= resultStage.jobId
+ jobIdToActiveJob -= resultStage.jobId
activeJobs -= job
resultStageToJob -= resultStage
listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error, failedStage.id)))
diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
index 5555585c8b4cd..b3f2cb346f7da 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
@@ -164,8 +164,7 @@ object InputFormatInfo {
PS: I know the wording here is weird, hopefully it makes some sense !
*/
- def computePreferredLocations(formats: Seq[InputFormatInfo]): HashMap[String, HashSet[SplitInfo]]
- = {
+ def computePreferredLocations(formats: Seq[InputFormatInfo]): Map[String, Set[SplitInfo]] = {
val nodeToSplit = new HashMap[String, HashSet[SplitInfo]]
for (inputSplit <- formats) {
@@ -178,6 +177,6 @@ object InputFormatInfo {
}
}
- nodeToSplit
+ nodeToSplit.mapValues(_.toSet).toMap
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 990e01a3e7959..7bfc30b4208a3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -172,7 +172,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
properties += ((key, value))
}
}
- //TODO (prashant) send conf instead of properties
+ // TODO (prashant) send conf instead of properties
driverActor = actorSystem.actorOf(
Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME)
}
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 6b6d814c1fe92..926e71573be32 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -107,7 +107,8 @@ class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends Deser
kryo.readClassAndObject(input).asInstanceOf[T]
} catch {
// DeserializationStream uses the EOF exception to indicate stopping condition.
- case _: KryoException => throw new EOFException
+ case e: KryoException if e.getMessage.toLowerCase.contains("buffer underflow") =>
+ throw new EOFException
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
index bcfc39146a61e..2fbbda5b76c74 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
@@ -284,7 +284,7 @@ object BlockFetcherIterator {
}
} catch {
case x: InterruptedException => logInfo("Copier Interrupted")
- //case _ => throw new SparkException("Exception Throw in Shuffle Copier")
+ // case _ => throw new SparkException("Exception Throw in Shuffle Copier")
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 71584b6eb102a..19138d9dde697 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -19,22 +19,20 @@ package org.apache.spark.storage
import java.io.{File, InputStream, OutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
-
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.util.Random
-
import akka.actor.{ActorSystem, Cancellable, Props}
import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream}
import sun.nio.ch.DirectBuffer
-
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
import org.apache.spark.serializer.Serializer
import org.apache.spark.util._
+
sealed trait Values
case class ByteBufferValues(buffer: ByteBuffer) extends Values
@@ -59,6 +57,17 @@ private[spark] class BlockManager(
private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
private[storage] val diskStore = new DiskStore(this, diskBlockManager)
+ var tachyonInitialized = false
+ private[storage] lazy val tachyonStore: TachyonStore = {
+ val storeDir = conf.get("spark.tachyonStore.baseDir", "/tmp_spark_tachyon")
+ val appFolderName = conf.get("spark.tachyonStore.folderName")
+ val tachyonStorePath = s"${storeDir}/${appFolderName}/${this.executorId}"
+ val tachyonMaster = conf.get("spark.tachyonStore.url", "tachyon://localhost:19998")
+ val tachyonBlockManager = new TachyonBlockManager(
+ shuffleBlockManager, tachyonStorePath, tachyonMaster)
+ tachyonInitialized = true
+ new TachyonStore(this, tachyonBlockManager)
+ }
// If we use Netty for shuffle, start a new Netty-based shuffle sender service.
private val nettyPort: Int = {
@@ -248,8 +257,10 @@ private[spark] class BlockManager(
if (info.tellMaster) {
val storageLevel = status.storageLevel
val inMemSize = Math.max(status.memSize, droppedMemorySize)
+ val inTachyonSize = status.tachyonSize
val onDiskSize = status.diskSize
- master.updateBlockInfo(blockManagerId, blockId, storageLevel, inMemSize, onDiskSize)
+ master.updateBlockInfo(
+ blockManagerId, blockId, storageLevel, inMemSize, onDiskSize, inTachyonSize)
} else true
}
@@ -259,22 +270,24 @@ private[spark] class BlockManager(
* and the updated in-memory and on-disk sizes.
*/
private def getCurrentBlockStatus(blockId: BlockId, info: BlockInfo): BlockStatus = {
- val (newLevel, inMemSize, onDiskSize) = info.synchronized {
+ val (newLevel, inMemSize, onDiskSize, inTachyonSize) = info.synchronized {
info.level match {
case null =>
- (StorageLevel.NONE, 0L, 0L)
+ (StorageLevel.NONE, 0L, 0L, 0L)
case level =>
val inMem = level.useMemory && memoryStore.contains(blockId)
+ val inTachyon = level.useOffHeap && tachyonStore.contains(blockId)
val onDisk = level.useDisk && diskStore.contains(blockId)
val deserialized = if (inMem) level.deserialized else false
- val replication = if (inMem || onDisk) level.replication else 1
- val storageLevel = StorageLevel(onDisk, inMem, deserialized, replication)
+ val replication = if (inMem || inTachyon || onDisk) level.replication else 1
+ val storageLevel = StorageLevel(onDisk, inMem, inTachyon, deserialized, replication)
val memSize = if (inMem) memoryStore.getSize(blockId) else 0L
+ val tachyonSize = if (inTachyon) tachyonStore.getSize(blockId) else 0L
val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L
- (storageLevel, memSize, diskSize)
+ (storageLevel, memSize, diskSize, tachyonSize)
}
}
- BlockStatus(newLevel, inMemSize, onDiskSize)
+ BlockStatus(newLevel, inMemSize, onDiskSize, inTachyonSize)
}
/**
@@ -354,6 +367,24 @@ private[spark] class BlockManager(
logDebug("Block " + blockId + " not found in memory")
}
}
+
+ // Look for the block in Tachyon
+ if (level.useOffHeap) {
+ logDebug("Getting block " + blockId + " from tachyon")
+ if (tachyonStore.contains(blockId)) {
+ tachyonStore.getBytes(blockId) match {
+ case Some(bytes) => {
+ if (!asValues) {
+ return Some(bytes)
+ } else {
+ return Some(dataDeserialize(blockId, bytes))
+ }
+ }
+ case None =>
+ logDebug("Block " + blockId + " not found in tachyon")
+ }
+ }
+ }
// Look for block on disk, potentially storing it back into memory if required:
if (level.useDisk) {
@@ -620,6 +651,23 @@ private[spark] class BlockManager(
}
// Keep track of which blocks are dropped from memory
res.droppedBlocks.foreach { block => updatedBlocks += block }
+ } else if (level.useOffHeap) {
+ // Save to Tachyon.
+ val res = data match {
+ case IteratorValues(iterator) =>
+ tachyonStore.putValues(blockId, iterator, level, false)
+ case ArrayBufferValues(array) =>
+ tachyonStore.putValues(blockId, array, level, false)
+ case ByteBufferValues(bytes) => {
+ bytes.rewind();
+ tachyonStore.putBytes(blockId, bytes, level)
+ }
+ }
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
+ case _ =>
+ }
} else {
// Save directly to disk.
// Don't get back the bytes unless we replicate them.
@@ -644,8 +692,8 @@ private[spark] class BlockManager(
val putBlockStatus = getCurrentBlockStatus(blockId, putBlockInfo)
if (putBlockStatus.storageLevel != StorageLevel.NONE) {
- // Now that the block is in either the memory or disk store, let other threads read it,
- // and tell the master about it.
+ // Now that the block is in either the memory, tachyon, or disk store,
+ // let other threads read it, and tell the master about it.
marked = true
putBlockInfo.markReady(size)
if (tellMaster) {
@@ -707,7 +755,8 @@ private[spark] class BlockManager(
*/
var cachedPeers: Seq[BlockManagerId] = null
private def replicate(blockId: BlockId, data: ByteBuffer, level: StorageLevel) {
- val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
+ val tLevel = StorageLevel(
+ level.useDisk, level.useMemory, level.useOffHeap, level.deserialized, 1)
if (cachedPeers == null) {
cachedPeers = master.getPeers(blockManagerId, level.replication - 1)
}
@@ -832,9 +881,10 @@ private[spark] class BlockManager(
// Removals are idempotent in disk store and memory store. At worst, we get a warning.
val removedFromMemory = memoryStore.remove(blockId)
val removedFromDisk = diskStore.remove(blockId)
- if (!removedFromMemory && !removedFromDisk) {
+ val removedFromTachyon = if (tachyonInitialized) tachyonStore.remove(blockId) else false
+ if (!removedFromMemory && !removedFromDisk && !removedFromTachyon) {
logWarning("Block " + blockId + " could not be removed as it was not found in either " +
- "the disk or memory store")
+ "the disk, memory, or tachyon store")
}
blockInfo.remove(blockId)
if (tellMaster && info.tellMaster) {
@@ -871,6 +921,9 @@ private[spark] class BlockManager(
if (level.useDisk) {
diskStore.remove(id)
}
+ if (level.useOffHeap) {
+ tachyonStore.remove(id)
+ }
iterator.remove()
logInfo("Dropped block " + id)
}
@@ -946,6 +999,9 @@ private[spark] class BlockManager(
blockInfo.clear()
memoryStore.clear()
diskStore.clear()
+ if (tachyonInitialized) {
+ tachyonStore.clear()
+ }
metadataCleaner.cancel()
broadcastCleaner.cancel()
logInfo("BlockManager stopped")
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index ed6937851b836..4bc1b407ad106 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -63,9 +63,10 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
- diskSize: Long): Boolean = {
+ diskSize: Long,
+ tachyonSize: Long): Boolean = {
val res = askDriverWithReply[Boolean](
- UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize))
+ UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize, tachyonSize))
logInfo("Updated info of block " + blockId)
res
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index ff2652b640272..378f4cadc17d7 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -73,10 +73,11 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
register(blockManagerId, maxMemSize, slaveActor)
sender ! true
- case UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) =>
+ case UpdateBlockInfo(
+ blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) =>
// TODO: Ideally we want to handle all the message replies in receive instead of in the
// individual private methods.
- updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size)
+ updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize)
case GetLocations(blockId) =>
sender ! getLocations(blockId)
@@ -246,7 +247,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
- diskSize: Long) {
+ diskSize: Long,
+ tachyonSize: Long) {
if (!blockManagerInfo.contains(blockManagerId)) {
if (blockManagerId.executorId == "" && !isLocal) {
@@ -265,7 +267,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
return
}
- blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize)
+ blockManagerInfo(blockManagerId).updateBlockInfo(
+ blockId, storageLevel, memSize, diskSize, tachyonSize)
var locations: mutable.HashSet[BlockManagerId] = null
if (blockLocations.containsKey(blockId)) {
@@ -309,8 +312,11 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
}
}
-
-private[spark] case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long)
+private[spark] case class BlockStatus(
+ storageLevel: StorageLevel,
+ memSize: Long,
+ diskSize: Long,
+ tachyonSize: Long)
private[spark] class BlockManagerInfo(
val blockManagerId: BlockManagerId,
@@ -336,7 +342,8 @@ private[spark] class BlockManagerInfo(
blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
- diskSize: Long) {
+ diskSize: Long,
+ tachyonSize: Long) {
updateLastSeenMs()
@@ -350,23 +357,29 @@ private[spark] class BlockManagerInfo(
}
if (storageLevel.isValid) {
- /* isValid means it is either stored in-memory or on-disk.
+ /* isValid means it is either stored in-memory, on-disk or on-Tachyon.
* But the memSize here indicates the data size in or dropped from memory,
+ * tachyonSize here indicates the data size in or dropped from Tachyon,
* and the diskSize here indicates the data size in or dropped to disk.
* They can be both larger than 0, when a block is dropped from memory to disk.
* Therefore, a safe way to set BlockStatus is to set its info in accurate modes. */
if (storageLevel.useMemory) {
- _blocks.put(blockId, BlockStatus(storageLevel, memSize, 0))
+ _blocks.put(blockId, BlockStatus(storageLevel, memSize, 0, 0))
_remainingMem -= memSize
logInfo("Added %s in memory on %s (size: %s, free: %s)".format(
blockId, blockManagerId.hostPort, Utils.bytesToString(memSize),
Utils.bytesToString(_remainingMem)))
}
if (storageLevel.useDisk) {
- _blocks.put(blockId, BlockStatus(storageLevel, 0, diskSize))
+ _blocks.put(blockId, BlockStatus(storageLevel, 0, diskSize, 0))
logInfo("Added %s on disk on %s (size: %s)".format(
blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize)))
}
+ if (storageLevel.useOffHeap) {
+ _blocks.put(blockId, BlockStatus(storageLevel, 0, 0, tachyonSize))
+ logInfo("Added %s on tachyon on %s (size: %s)".format(
+ blockId, blockManagerId.hostPort, Utils.bytesToString(tachyonSize)))
+ }
} else if (_blocks.containsKey(blockId)) {
// If isValid is not true, drop the block.
val blockStatus: BlockStatus = _blocks.get(blockId)
@@ -381,6 +394,10 @@ private[spark] class BlockManagerInfo(
logInfo("Removed %s on %s on disk (size: %s)".format(
blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize)))
}
+ if (blockStatus.storageLevel.useOffHeap) {
+ logInfo("Removed %s on %s on tachyon (size: %s)".format(
+ blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.tachyonSize)))
+ }
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index bbb9529b5a0ca..8a36b5cc42dfd 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -53,11 +53,12 @@ private[storage] object BlockManagerMessages {
var blockId: BlockId,
var storageLevel: StorageLevel,
var memSize: Long,
- var diskSize: Long)
+ var diskSize: Long,
+ var tachyonSize: Long)
extends ToBlockManagerMaster
with Externalizable {
- def this() = this(null, null, null, 0, 0) // For deserialization only
+ def this() = this(null, null, null, 0, 0, 0) // For deserialization only
override def writeExternal(out: ObjectOutput) {
blockManagerId.writeExternal(out)
@@ -65,6 +66,7 @@ private[storage] object BlockManagerMessages {
storageLevel.writeExternal(out)
out.writeLong(memSize)
out.writeLong(diskSize)
+ out.writeLong(tachyonSize)
}
override def readExternal(in: ObjectInput) {
@@ -73,6 +75,7 @@ private[storage] object BlockManagerMessages {
storageLevel = StorageLevel(in)
memSize = in.readLong()
diskSize = in.readLong()
+ tachyonSize = in.readLong()
}
}
@@ -81,13 +84,15 @@ private[storage] object BlockManagerMessages {
blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
- diskSize: Long): UpdateBlockInfo = {
- new UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize)
+ diskSize: Long,
+ tachyonSize: Long): UpdateBlockInfo = {
+ new UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize, tachyonSize)
}
// For pattern-matching
- def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, BlockId, StorageLevel, Long, Long)] = {
- Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize))
+ def unapply(h: UpdateBlockInfo)
+ : Option[(BlockManagerId, BlockId, StorageLevel, Long, Long, Long)] = {
+ Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize, h.tachyonSize))
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
index 1b7934d59fa1d..95e71de2d3f1d 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
@@ -21,8 +21,9 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
/**
* Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory,
- * whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory
- * in a serialized format, and whether to replicate the RDD partitions on multiple nodes.
+ * or Tachyon, whether to drop the RDD to disk if it falls out of memory or Tachyon , whether to
+ * keep the data in memory in a serialized format, and whether to replicate the RDD partitions on
+ * multiple nodes.
* The [[org.apache.spark.storage.StorageLevel$]] singleton object contains some static constants
* for commonly useful storage levels. To create your own storage level object, use the
* factory method of the singleton object (`StorageLevel(...)`).
@@ -30,45 +31,58 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
class StorageLevel private(
private var useDisk_ : Boolean,
private var useMemory_ : Boolean,
+ private var useOffHeap_ : Boolean,
private var deserialized_ : Boolean,
private var replication_ : Int = 1)
extends Externalizable {
// TODO: Also add fields for caching priority, dataset ID, and flushing.
private def this(flags: Int, replication: Int) {
- this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication)
+ this((flags & 8) != 0, (flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication)
}
- def this() = this(false, true, false) // For deserialization
+ def this() = this(false, true, false, false) // For deserialization
def useDisk = useDisk_
def useMemory = useMemory_
+ def useOffHeap = useOffHeap_
def deserialized = deserialized_
def replication = replication_
assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes")
+ if (useOffHeap) {
+ require(useDisk == false, "Off-heap storage level does not support using disk")
+ require(useMemory == false, "Off-heap storage level does not support using heap memory")
+ require(deserialized == false, "Off-heap storage level does not support deserialized storage")
+ require(replication == 1, "Off-heap storage level does not support multiple replication")
+ }
+
override def clone(): StorageLevel = new StorageLevel(
- this.useDisk, this.useMemory, this.deserialized, this.replication)
+ this.useDisk, this.useMemory, this.useOffHeap, this.deserialized, this.replication)
override def equals(other: Any): Boolean = other match {
case s: StorageLevel =>
s.useDisk == useDisk &&
s.useMemory == useMemory &&
+ s.useOffHeap == useOffHeap &&
s.deserialized == deserialized &&
s.replication == replication
case _ =>
false
}
- def isValid = ((useMemory || useDisk) && (replication > 0))
+ def isValid = ((useMemory || useDisk || useOffHeap) && (replication > 0))
def toInt: Int = {
var ret = 0
if (useDisk_) {
- ret |= 4
+ ret |= 8
}
if (useMemory_) {
+ ret |= 4
+ }
+ if (useOffHeap_) {
ret |= 2
}
if (deserialized_) {
@@ -84,8 +98,9 @@ class StorageLevel private(
override def readExternal(in: ObjectInput) {
val flags = in.readByte()
- useDisk_ = (flags & 4) != 0
- useMemory_ = (flags & 2) != 0
+ useDisk_ = (flags & 8) != 0
+ useMemory_ = (flags & 4) != 0
+ useOffHeap_ = (flags & 2) != 0
deserialized_ = (flags & 1) != 0
replication_ = in.readByte()
}
@@ -93,14 +108,15 @@ class StorageLevel private(
@throws(classOf[IOException])
private def readResolve(): Object = StorageLevel.getCachedStorageLevel(this)
- override def toString: String =
- "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication)
+ override def toString: String = "StorageLevel(%b, %b, %b, %b, %d)".format(
+ useDisk, useMemory, useOffHeap, deserialized, replication)
override def hashCode(): Int = toInt * 41 + replication
def description : String = {
var result = ""
result += (if (useDisk) "Disk " else "")
result += (if (useMemory) "Memory " else "")
+ result += (if (useOffHeap) "Tachyon " else "")
result += (if (deserialized) "Deserialized " else "Serialized ")
result += "%sx Replicated".format(replication)
result
@@ -113,28 +129,35 @@ class StorageLevel private(
* new storage levels.
*/
object StorageLevel {
- val NONE = new StorageLevel(false, false, false)
- val DISK_ONLY = new StorageLevel(true, false, false)
- val DISK_ONLY_2 = new StorageLevel(true, false, false, 2)
- val MEMORY_ONLY = new StorageLevel(false, true, true)
- val MEMORY_ONLY_2 = new StorageLevel(false, true, true, 2)
- val MEMORY_ONLY_SER = new StorageLevel(false, true, false)
- val MEMORY_ONLY_SER_2 = new StorageLevel(false, true, false, 2)
- val MEMORY_AND_DISK = new StorageLevel(true, true, true)
- val MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2)
- val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false)
- val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2)
+ val NONE = new StorageLevel(false, false, false, false)
+ val DISK_ONLY = new StorageLevel(true, false, false, false)
+ val DISK_ONLY_2 = new StorageLevel(true, false, false, false, 2)
+ val MEMORY_ONLY = new StorageLevel(false, true, false, true)
+ val MEMORY_ONLY_2 = new StorageLevel(false, true, false, true, 2)
+ val MEMORY_ONLY_SER = new StorageLevel(false, true, false, false)
+ val MEMORY_ONLY_SER_2 = new StorageLevel(false, true, false, false, 2)
+ val MEMORY_AND_DISK = new StorageLevel(true, true, false, true)
+ val MEMORY_AND_DISK_2 = new StorageLevel(true, true, false, true, 2)
+ val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, false)
+ val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, false, 2)
+ val OFF_HEAP = new StorageLevel(false, false, true, false)
+
+ /** Create a new StorageLevel object without setting useOffHeap */
+ def apply(useDisk: Boolean, useMemory: Boolean, useOffHeap: Boolean,
+ deserialized: Boolean, replication: Int) = getCachedStorageLevel(
+ new StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication))
/** Create a new StorageLevel object */
- def apply(useDisk: Boolean, useMemory: Boolean, deserialized: Boolean, replication: Int = 1) =
- getCachedStorageLevel(new StorageLevel(useDisk, useMemory, deserialized, replication))
+ def apply(useDisk: Boolean, useMemory: Boolean,
+ deserialized: Boolean, replication: Int = 1) = getCachedStorageLevel(
+ new StorageLevel(useDisk, useMemory, false, deserialized, replication))
/** Create a new StorageLevel object from its integer representation */
- def apply(flags: Int, replication: Int) =
+ def apply(flags: Int, replication: Int): StorageLevel =
getCachedStorageLevel(new StorageLevel(flags, replication))
/** Read StorageLevel object from ObjectInput stream */
- def apply(in: ObjectInput) = {
+ def apply(in: ObjectInput): StorageLevel = {
val obj = new StorageLevel()
obj.readExternal(in)
getCachedStorageLevel(obj)
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
index 26565f56ad858..7a174959037be 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
@@ -44,7 +44,7 @@ private[spark] class StorageStatusListener extends SparkListener {
storageStatusList.foreach { storageStatus =>
val unpersistedBlocksIds = storageStatus.rddBlocks.keys.filter(_.rddId == unpersistedRDDId)
unpersistedBlocksIds.foreach { blockId =>
- storageStatus.blocks(blockId) = BlockStatus(StorageLevel.NONE, 0L, 0L)
+ storageStatus.blocks(blockId) = BlockStatus(StorageLevel.NONE, 0L, 0L, 0L)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
index 6153dfe0b7e13..ff6e84cf9819a 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
@@ -48,17 +48,23 @@ class StorageStatus(
}
private[spark]
-class RDDInfo(val id: Int, val name: String, val numPartitions: Int, val storageLevel: StorageLevel)
- extends Ordered[RDDInfo] {
+class RDDInfo(
+ val id: Int,
+ val name: String,
+ val numPartitions: Int,
+ val storageLevel: StorageLevel) extends Ordered[RDDInfo] {
var numCachedPartitions = 0
var memSize = 0L
var diskSize = 0L
+ var tachyonSize= 0L
override def toString = {
- ("RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; " +
- "DiskSize: %s").format(name, id, storageLevel.toString, numCachedPartitions,
- numPartitions, Utils.bytesToString(memSize), Utils.bytesToString(diskSize))
+ import Utils.bytesToString
+ ("RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s;" +
+ "TachyonSize: %s; DiskSize: %s").format(
+ name, id, storageLevel.toString, numCachedPartitions, numPartitions,
+ bytesToString(memSize), bytesToString(tachyonSize), bytesToString(diskSize))
}
override def compare(that: RDDInfo) = {
@@ -105,14 +111,17 @@ object StorageUtils {
val rddInfoMap = rddInfos.map { info => (info.id, info) }.toMap
val rddStorageInfos = blockStatusMap.flatMap { case (rddId, blocks) =>
- // Add up memory and disk sizes
- val persistedBlocks = blocks.filter { status => status.memSize + status.diskSize > 0 }
+ // Add up memory, disk and Tachyon sizes
+ val persistedBlocks =
+ blocks.filter { status => status.memSize + status.diskSize + status.tachyonSize > 0 }
val memSize = persistedBlocks.map(_.memSize).reduceOption(_ + _).getOrElse(0L)
val diskSize = persistedBlocks.map(_.diskSize).reduceOption(_ + _).getOrElse(0L)
+ val tachyonSize = persistedBlocks.map(_.tachyonSize).reduceOption(_ + _).getOrElse(0L)
rddInfoMap.get(rddId).map { rddInfo =>
rddInfo.numCachedPartitions = persistedBlocks.length
rddInfo.memSize = memSize
rddInfo.diskSize = diskSize
+ rddInfo.tachyonSize = tachyonSize
rddInfo
}
}.toArray
diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
new file mode 100644
index 0000000000000..b0b9674856568
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.text.SimpleDateFormat
+import java.util.{Date, Random}
+
+import tachyon.client.TachyonFS
+import tachyon.client.TachyonFile
+
+import org.apache.spark.Logging
+import org.apache.spark.executor.ExecutorExitCode
+import org.apache.spark.network.netty.ShuffleSender
+import org.apache.spark.util.Utils
+
+
+/**
+ * Creates and maintains the logical mapping between logical blocks and tachyon fs locations. By
+ * default, one block is mapped to one file with a name given by its BlockId.
+ *
+ * @param rootDirs The directories to use for storing block files. Data will be hashed among these.
+ */
+private[spark] class TachyonBlockManager(
+ shuffleManager: ShuffleBlockManager,
+ rootDirs: String,
+ val master: String)
+ extends Logging {
+
+ val client = if (master != null && master != "") TachyonFS.get(master) else null
+
+ if (client == null) {
+ logError("Failed to connect to the Tachyon as the master address is not configured")
+ System.exit(ExecutorExitCode.TACHYON_STORE_FAILED_TO_INITIALIZE)
+ }
+
+ private val MAX_DIR_CREATION_ATTEMPTS = 10
+ private val subDirsPerTachyonDir =
+ shuffleManager.conf.get("spark.tachyonStore.subDirectories", "64").toInt
+
+ // Create one Tachyon directory for each path mentioned in spark.tachyonStore.folderName;
+ // then, inside this directory, create multiple subdirectories that we will hash files into,
+ // in order to avoid having really large inodes at the top level in Tachyon.
+ private val tachyonDirs: Array[TachyonFile] = createTachyonDirs()
+ private val subDirs = Array.fill(tachyonDirs.length)(new Array[TachyonFile](subDirsPerTachyonDir))
+
+ addShutdownHook()
+
+ def removeFile(file: TachyonFile): Boolean = {
+ client.delete(file.getPath(), false)
+ }
+
+ def fileExists(file: TachyonFile): Boolean = {
+ client.exist(file.getPath())
+ }
+
+ def getFile(filename: String): TachyonFile = {
+ // Figure out which tachyon directory it hashes to, and which subdirectory in that
+ val hash = Utils.nonNegativeHash(filename)
+ val dirId = hash % tachyonDirs.length
+ val subDirId = (hash / tachyonDirs.length) % subDirsPerTachyonDir
+
+ // Create the subdirectory if it doesn't already exist
+ var subDir = subDirs(dirId)(subDirId)
+ if (subDir == null) {
+ subDir = subDirs(dirId).synchronized {
+ val old = subDirs(dirId)(subDirId)
+ if (old != null) {
+ old
+ } else {
+ val path = tachyonDirs(dirId) + "/" + "%02x".format(subDirId)
+ client.mkdir(path)
+ val newDir = client.getFile(path)
+ subDirs(dirId)(subDirId) = newDir
+ newDir
+ }
+ }
+ }
+ val filePath = subDir + "/" + filename
+ if(!client.exist(filePath)) {
+ client.createFile(filePath)
+ }
+ val file = client.getFile(filePath)
+ file
+ }
+
+ def getFile(blockId: BlockId): TachyonFile = getFile(blockId.name)
+
+ // TODO: Some of the logic here could be consolidated/de-duplicated with that in the DiskStore.
+ private def createTachyonDirs(): Array[TachyonFile] = {
+ logDebug("Creating tachyon directories at root dirs '" + rootDirs + "'")
+ val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
+ rootDirs.split(",").map { rootDir =>
+ var foundLocalDir = false
+ var tachyonDir: TachyonFile = null
+ var tachyonDirId: String = null
+ var tries = 0
+ val rand = new Random()
+ while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
+ tries += 1
+ try {
+ tachyonDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
+ val path = rootDir + "/" + "spark-tachyon-" + tachyonDirId
+ if (!client.exist(path)) {
+ foundLocalDir = client.mkdir(path)
+ tachyonDir = client.getFile(path)
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create tachyon dir " + tachyonDir + " failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed " + MAX_DIR_CREATION_ATTEMPTS + " attempts to create tachyon dir in " +
+ rootDir)
+ System.exit(ExecutorExitCode.TACHYON_STORE_FAILED_TO_CREATE_DIR)
+ }
+ logInfo("Created tachyon directory at " + tachyonDir)
+ tachyonDir
+ }
+ }
+
+ private def addShutdownHook() {
+ tachyonDirs.foreach(tachyonDir => Utils.registerShutdownDeleteDir(tachyonDir))
+ Runtime.getRuntime.addShutdownHook(new Thread("delete Spark tachyon dirs") {
+ override def run() {
+ logDebug("Shutdown hook called")
+ tachyonDirs.foreach { tachyonDir =>
+ try {
+ if (!Utils.hasRootAsShutdownDeleteDir(tachyonDir)) {
+ Utils.deleteRecursively(tachyonDir, client)
+ }
+ } catch {
+ case t: Throwable =>
+ logError("Exception while deleting tachyon spark dir: " + tachyonDir, t)
+ }
+ }
+ }
+ })
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala b/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala
new file mode 100644
index 0000000000000..b86abbda1d3e7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import tachyon.client.TachyonFile
+
+/**
+ * References a particular segment of a file (potentially the entire file), based off an offset and
+ * a length.
+ */
+private[spark] class TachyonFileSegment(val file: TachyonFile, val offset: Long, val length: Long) {
+ override def toString = "(name=%s, offset=%d, length=%d)".format(file.getPath(), offset, length)
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala
new file mode 100644
index 0000000000000..c37e76f893605
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.IOException
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+import tachyon.client.{WriteType, ReadType}
+
+import org.apache.spark.Logging
+import org.apache.spark.util.Utils
+import org.apache.spark.serializer.Serializer
+
+
+private class Entry(val size: Long)
+
+
+/**
+ * Stores BlockManager blocks on Tachyon.
+ */
+private class TachyonStore(
+ blockManager: BlockManager,
+ tachyonManager: TachyonBlockManager)
+ extends BlockStore(blockManager: BlockManager) with Logging {
+
+ logInfo("TachyonStore started")
+
+ override def getSize(blockId: BlockId): Long = {
+ tachyonManager.getFile(blockId.name).length
+ }
+
+ override def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel): PutResult = {
+ putToTachyonStore(blockId, bytes, true)
+ }
+
+ override def putValues(
+ blockId: BlockId,
+ values: ArrayBuffer[Any],
+ level: StorageLevel,
+ returnValues: Boolean): PutResult = {
+ return putValues(blockId, values.toIterator, level, returnValues)
+ }
+
+ override def putValues(
+ blockId: BlockId,
+ values: Iterator[Any],
+ level: StorageLevel,
+ returnValues: Boolean): PutResult = {
+ logDebug("Attempting to write values for block " + blockId)
+ val _bytes = blockManager.dataSerialize(blockId, values)
+ putToTachyonStore(blockId, _bytes, returnValues)
+ }
+
+ private def putToTachyonStore(
+ blockId: BlockId,
+ bytes: ByteBuffer,
+ returnValues: Boolean): PutResult = {
+ // So that we do not modify the input offsets !
+ // duplicate does not copy buffer, so inexpensive
+ val byteBuffer = bytes.duplicate()
+ byteBuffer.rewind()
+ logDebug("Attempting to put block " + blockId + " into Tachyon")
+ val startTime = System.currentTimeMillis
+ val file = tachyonManager.getFile(blockId)
+ val os = file.getOutStream(WriteType.TRY_CACHE)
+ os.write(byteBuffer.array())
+ os.close()
+ val finishTime = System.currentTimeMillis
+ logDebug("Block %s stored as %s file in Tachyon in %d ms".format(
+ blockId, Utils.bytesToString(byteBuffer.limit), (finishTime - startTime)))
+
+ if (returnValues) {
+ PutResult(bytes.limit(), Right(bytes.duplicate()))
+ } else {
+ PutResult(bytes.limit(), null)
+ }
+ }
+
+ override def remove(blockId: BlockId): Boolean = {
+ val file = tachyonManager.getFile(blockId)
+ if (tachyonManager.fileExists(file)) {
+ tachyonManager.removeFile(file)
+ } else {
+ false
+ }
+ }
+
+ override def getValues(blockId: BlockId): Option[Iterator[Any]] = {
+ getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer))
+ }
+
+
+ override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
+ val file = tachyonManager.getFile(blockId)
+ if (file == null || file.getLocationHosts().size == 0) {
+ return None
+ }
+ val is = file.getInStream(ReadType.CACHE)
+ var buffer: ByteBuffer = null
+ try {
+ if (is != null) {
+ val size = file.length
+ val bs = new Array[Byte](size.asInstanceOf[Int])
+ val fetchSize = is.read(bs, 0, size.asInstanceOf[Int])
+ buffer = ByteBuffer.wrap(bs)
+ if (fetchSize != size) {
+ logWarning("Failed to fetch the block " + blockId + " from Tachyon : Size " + size +
+ " is not equal to fetched size " + fetchSize)
+ return None
+ }
+ }
+ } catch {
+ case ioe: IOException => {
+ logWarning("Failed to fetch the block " + blockId + " from Tachyon", ioe)
+ return None
+ }
+ }
+ Some(buffer)
+ }
+
+ override def contains(blockId: BlockId): Boolean = {
+ val file = tachyonManager.getFile(blockId)
+ tachyonManager.fileExists(file)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index 6e1736f6fbc23..e1a1f209c9282 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -18,13 +18,14 @@
package org.apache.spark.ui
import java.net.{InetSocketAddress, URL}
+import javax.servlet.DispatcherType
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
import scala.annotation.tailrec
import scala.util.{Failure, Success, Try}
import scala.xml.Node
-import org.eclipse.jetty.server.{DispatcherType, Server}
+import org.eclipse.jetty.server.Server
import org.eclipse.jetty.server.handler._
import org.eclipse.jetty.servlet._
import org.eclipse.jetty.util.thread.QueuedThreadPool
diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
index fd638c83aac6e..ef1ad872c8ef7 100644
--- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
@@ -47,7 +47,8 @@ private[spark] class SparkUI(
val securityManager = if (live) sc.env.securityManager else new SecurityManager(conf)
- private val host = Option(System.getenv("SPARK_PUBLIC_DNS")).getOrElse(Utils.localHostName())
+ private val bindHost = Utils.localHostName()
+ private val publicHost = Option(System.getenv("SPARK_PUBLIC_DNS")).getOrElse(bindHost)
private val port = conf.get("spark.ui.port", SparkUI.DEFAULT_PORT).toInt
private var serverInfo: Option[ServerInfo] = None
@@ -79,8 +80,8 @@ private[spark] class SparkUI(
/** Bind the HTTP server which backs this web interface */
def bind() {
try {
- serverInfo = Some(startJettyServer(host, port, handlers, sc.conf))
- logInfo("Started Spark Web UI at http://%s:%d".format(host, boundPort))
+ serverInfo = Some(startJettyServer(bindHost, port, handlers, sc.conf))
+ logInfo("Started Spark Web UI at http://%s:%d".format(publicHost, boundPort))
} catch {
case e: Exception =>
logError("Failed to create Spark JettyUtils", e)
@@ -111,7 +112,7 @@ private[spark] class SparkUI(
logInfo("Stopped Spark Web UI at %s".format(appUIAddress))
}
- private[spark] def appUIAddress = "http://" + host + ":" + boundPort
+ private[spark] def appUIAddress = "http://" + publicHost + ":" + boundPort
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
index f3c93d4214ad0..70d62b66a4829 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
@@ -25,7 +25,7 @@ import org.apache.spark.scheduler.Schedulable
import org.apache.spark.ui.Page._
import org.apache.spark.ui.UIUtils
-/** Page showing list of all ongoing and recently finished stages and pools*/
+/** Page showing list of all ongoing and recently finished stages and pools */
private[ui] class IndexPage(parent: JobProgressUI) {
private val appName = parent.appName
private val basePath = parent.basePath
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index d10aa12b9ebca..048f671c8788f 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -81,9 +81,8 @@ private[ui] class JobProgressListener(conf: SparkConf) extends SparkListener {
/** If stages is too large, remove and garbage collect old stages */
private def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized {
if (stages.size > retainedStages) {
- val toRemove = retainedStages / 10
- stages.takeRight(toRemove).foreach( s => {
- stageIdToTaskData.remove(s.stageId)
+ val toRemove = math.max(retainedStages / 10, 1)
+ stages.take(toRemove).foreach { s =>
stageIdToTime.remove(s.stageId)
stageIdToShuffleRead.remove(s.stageId)
stageIdToShuffleWrite.remove(s.stageId)
@@ -92,10 +91,12 @@ private[ui] class JobProgressListener(conf: SparkConf) extends SparkListener {
stageIdToTasksActive.remove(s.stageId)
stageIdToTasksComplete.remove(s.stageId)
stageIdToTasksFailed.remove(s.stageId)
+ stageIdToTaskData.remove(s.stageId)
+ stageIdToExecutorSummaries.remove(s.stageId)
stageIdToPool.remove(s.stageId)
- if (stageIdToDescription.contains(s.stageId)) {stageIdToDescription.remove(s.stageId)}
- })
- stages.trimEnd(toRemove)
+ stageIdToDescription.remove(s.stageId)
+ }
+ stages.trimStart(toRemove)
}
}
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
index 4d8b01dbe6e1b..a7b24ff695214 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
@@ -84,7 +84,7 @@ private[ui] class BlockManagerListener(storageStatusListener: StorageStatusListe
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = synchronized {
val rddInfo = stageSubmitted.stageInfo.rddInfo
- _rddInfoMap(rddInfo.id) = rddInfo
+ _rddInfoMap.getOrElseUpdate(rddInfo.id, rddInfo)
}
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized {
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala
index b2732de51058a..0fa461e5e9d27 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala
@@ -33,6 +33,7 @@ private[ui] class IndexPage(parent: BlockManagerUI) {
private lazy val listener = parent.listener
def render(request: HttpServletRequest): Seq[Node] = {
+
val rdds = listener.rddInfoList
val content = UIUtils.listingTable(rddHeader, rddRow, rdds)
UIUtils.headerSparkPage(content, basePath, appName, "Storage ", Storage)
@@ -45,6 +46,7 @@ private[ui] class IndexPage(parent: BlockManagerUI) {
"Cached Partitions",
"Fraction Cached",
"Size in Memory",
+ "Size in Tachyon",
"Size on Disk")
/** Render an HTML row representing an RDD */
@@ -60,6 +62,7 @@ private[ui] class IndexPage(parent: BlockManagerUI) {
{rdd.numCachedPartitions} |
{"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)} |
{Utils.bytesToString(rdd.memSize)} |
+ {Utils.bytesToString(rdd.tachyonSize)} |
{Utils.bytesToString(rdd.diskSize)} |
}
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index a8d20ee332355..cdbbc65292188 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -112,7 +112,7 @@ private[spark] object ClosureCleaner extends Logging {
accessedFields(cls) = Set[String]()
for (cls <- func.getClass :: innerClasses)
getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0)
- //logInfo("accessedFields: " + accessedFields)
+ // logInfo("accessedFields: " + accessedFields)
val inInterpreter = {
try {
@@ -139,13 +139,13 @@ private[spark] object ClosureCleaner extends Logging {
val field = cls.getDeclaredField(fieldName)
field.setAccessible(true)
val value = field.get(obj)
- //logInfo("1: Setting " + fieldName + " on " + cls + " to " + value);
+ // logInfo("1: Setting " + fieldName + " on " + cls + " to " + value);
field.set(outer, value)
}
}
if (outer != null) {
- //logInfo("2: Setting $outer on " + func.getClass + " to " + outer);
+ // logInfo("2: Setting $outer on " + func.getClass + " to " + outer);
val field = func.getClass.getDeclaredField("$outer")
field.setAccessible(true)
field.set(func, outer)
@@ -153,7 +153,7 @@ private[spark] object ClosureCleaner extends Logging {
}
private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = {
- //logInfo("Creating a " + cls + " with outer = " + outer)
+ // logInfo("Creating a " + cls + " with outer = " + outer)
if (!inInterpreter) {
// This is a bona fide closure class, whose constructor has no effects
// other than to set its fields, so use its constructor
@@ -170,7 +170,7 @@ private[spark] object ClosureCleaner extends Logging {
val newCtor = rf.newConstructorForSerialization(cls, parentCtor)
val obj = newCtor.newInstance().asInstanceOf[AnyRef]
if (outer != null) {
- //logInfo("3: Setting $outer on " + cls + " to " + outer);
+ // logInfo("3: Setting $outer on " + cls + " to " + outer);
val field = cls.getDeclaredField("$outer")
field.setAccessible(true)
field.set(obj, outer)
diff --git a/core/src/main/scala/org/apache/spark/util/Distribution.scala b/core/src/main/scala/org/apache/spark/util/Distribution.scala
index ab738c4b868fa..5b347555fe708 100644
--- a/core/src/main/scala/org/apache/spark/util/Distribution.scala
+++ b/core/src/main/scala/org/apache/spark/util/Distribution.scala
@@ -19,6 +19,8 @@ package org.apache.spark.util
import java.io.PrintStream
+import scala.collection.immutable.IndexedSeq
+
/**
* Util for getting some stats from a small sample of numeric values, with some handy
* summary functions.
@@ -40,7 +42,8 @@ class Distribution(val data: Array[Double], val startIdx: Int, val endIdx: Int)
* given from 0 to 1
* @param probabilities
*/
- def getQuantiles(probabilities: Traversable[Double] = defaultProbabilities) = {
+ def getQuantiles(probabilities: Traversable[Double] = defaultProbabilities)
+ : IndexedSeq[Double] = {
probabilities.toIndexedSeq.map{p:Double => data(closestIndex(p))}
}
@@ -48,7 +51,7 @@ class Distribution(val data: Array[Double], val startIdx: Int, val endIdx: Int)
math.min((p * length).toInt + startIdx, endIdx - 1)
}
- def showQuantiles(out: PrintStream = System.out) = {
+ def showQuantiles(out: PrintStream = System.out): Unit = {
out.println("min\t25%\t50%\t75%\tmax")
getQuantiles(defaultProbabilities).foreach{q => out.print(q + "\t")}
out.println
diff --git a/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala b/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala
index c539d2f708f95..4188a869c13da 100644
--- a/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala
+++ b/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala
@@ -49,7 +49,7 @@ private[akka] class IndestructibleActorSystemImpl(
if (isFatalError(cause) && !settings.JvmExitOnFatalError) {
log.error(cause, "Uncaught fatal error from thread [{}] not shutting down " +
"ActorSystem [{}] tolerating and continuing.... ", thread.getName, name)
- //shutdown() //TODO make it configurable
+ // shutdown() //TODO make it configurable
} else {
fallbackHandler.uncaughtException(thread, cause)
}
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index 346f2b7856791..2155a8888c85c 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -195,7 +195,7 @@ private[spark] object JsonProtocol {
taskMetrics.shuffleWriteMetrics.map(shuffleWriteMetricsToJson).getOrElse(JNothing)
val updatedBlocks = taskMetrics.updatedBlocks.map { blocks =>
JArray(blocks.toList.map { case (id, status) =>
- ("Block ID" -> blockIdToJson(id)) ~
+ ("Block ID" -> id.toString) ~
("Status" -> blockStatusToJson(status))
})
}.getOrElse(JNothing)
@@ -274,49 +274,23 @@ private[spark] object JsonProtocol {
("Number of Partitions" -> rddInfo.numPartitions) ~
("Number of Cached Partitions" -> rddInfo.numCachedPartitions) ~
("Memory Size" -> rddInfo.memSize) ~
+ ("Tachyon Size" -> rddInfo.tachyonSize) ~
("Disk Size" -> rddInfo.diskSize)
}
def storageLevelToJson(storageLevel: StorageLevel): JValue = {
("Use Disk" -> storageLevel.useDisk) ~
("Use Memory" -> storageLevel.useMemory) ~
+ ("Use Tachyon" -> storageLevel.useOffHeap) ~
("Deserialized" -> storageLevel.deserialized) ~
("Replication" -> storageLevel.replication)
}
- def blockIdToJson(blockId: BlockId): JValue = {
- val blockType = Utils.getFormattedClassName(blockId)
- val json: JObject = blockId match {
- case rddBlockId: RDDBlockId =>
- ("RDD ID" -> rddBlockId.rddId) ~
- ("Split Index" -> rddBlockId.splitIndex)
- case shuffleBlockId: ShuffleBlockId =>
- ("Shuffle ID" -> shuffleBlockId.shuffleId) ~
- ("Map ID" -> shuffleBlockId.mapId) ~
- ("Reduce ID" -> shuffleBlockId.reduceId)
- case broadcastBlockId: BroadcastBlockId =>
- "Broadcast ID" -> broadcastBlockId.broadcastId
- case broadcastHelperBlockId: BroadcastHelperBlockId =>
- ("Broadcast Block ID" -> blockIdToJson(broadcastHelperBlockId.broadcastId)) ~
- ("Helper Type" -> broadcastHelperBlockId.hType)
- case taskResultBlockId: TaskResultBlockId =>
- "Task ID" -> taskResultBlockId.taskId
- case streamBlockId: StreamBlockId =>
- ("Stream ID" -> streamBlockId.streamId) ~
- ("Unique ID" -> streamBlockId.uniqueId)
- case tempBlockId: TempBlockId =>
- val uuid = UUIDToJson(tempBlockId.id)
- "Temp ID" -> uuid
- case testBlockId: TestBlockId =>
- "Test ID" -> testBlockId.id
- }
- ("Type" -> blockType) ~ json
- }
-
def blockStatusToJson(blockStatus: BlockStatus): JValue = {
val storageLevel = storageLevelToJson(blockStatus.storageLevel)
("Storage Level" -> storageLevel) ~
("Memory Size" -> blockStatus.memSize) ~
+ ("Tachyon Size" -> blockStatus.tachyonSize) ~
("Disk Size" -> blockStatus.diskSize)
}
@@ -513,7 +487,7 @@ private[spark] object JsonProtocol {
Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson)
metrics.updatedBlocks = Utils.jsonOption(json \ "Updated Blocks").map { value =>
value.extract[List[JValue]].map { block =>
- val id = blockIdFromJson(block \ "Block ID")
+ val id = BlockId((block \ "Block ID").extract[String])
val status = blockStatusFromJson(block \ "Status")
(id, status)
}
@@ -599,11 +573,13 @@ private[spark] object JsonProtocol {
val numPartitions = (json \ "Number of Partitions").extract[Int]
val numCachedPartitions = (json \ "Number of Cached Partitions").extract[Int]
val memSize = (json \ "Memory Size").extract[Long]
+ val tachyonSize = (json \ "Tachyon Size").extract[Long]
val diskSize = (json \ "Disk Size").extract[Long]
val rddInfo = new RDDInfo(rddId, name, numPartitions, storageLevel)
rddInfo.numCachedPartitions = numCachedPartitions
rddInfo.memSize = memSize
+ rddInfo.tachyonSize = tachyonSize
rddInfo.diskSize = diskSize
rddInfo
}
@@ -611,60 +587,18 @@ private[spark] object JsonProtocol {
def storageLevelFromJson(json: JValue): StorageLevel = {
val useDisk = (json \ "Use Disk").extract[Boolean]
val useMemory = (json \ "Use Memory").extract[Boolean]
+ val useTachyon = (json \ "Use Tachyon").extract[Boolean]
val deserialized = (json \ "Deserialized").extract[Boolean]
val replication = (json \ "Replication").extract[Int]
- StorageLevel(useDisk, useMemory, deserialized, replication)
- }
-
- def blockIdFromJson(json: JValue): BlockId = {
- val rddBlockId = Utils.getFormattedClassName(RDDBlockId)
- val shuffleBlockId = Utils.getFormattedClassName(ShuffleBlockId)
- val broadcastBlockId = Utils.getFormattedClassName(BroadcastBlockId)
- val broadcastHelperBlockId = Utils.getFormattedClassName(BroadcastHelperBlockId)
- val taskResultBlockId = Utils.getFormattedClassName(TaskResultBlockId)
- val streamBlockId = Utils.getFormattedClassName(StreamBlockId)
- val tempBlockId = Utils.getFormattedClassName(TempBlockId)
- val testBlockId = Utils.getFormattedClassName(TestBlockId)
-
- (json \ "Type").extract[String] match {
- case `rddBlockId` =>
- val rddId = (json \ "RDD ID").extract[Int]
- val splitIndex = (json \ "Split Index").extract[Int]
- new RDDBlockId(rddId, splitIndex)
- case `shuffleBlockId` =>
- val shuffleId = (json \ "Shuffle ID").extract[Int]
- val mapId = (json \ "Map ID").extract[Int]
- val reduceId = (json \ "Reduce ID").extract[Int]
- new ShuffleBlockId(shuffleId, mapId, reduceId)
- case `broadcastBlockId` =>
- val broadcastId = (json \ "Broadcast ID").extract[Long]
- new BroadcastBlockId(broadcastId)
- case `broadcastHelperBlockId` =>
- val broadcastBlockId =
- blockIdFromJson(json \ "Broadcast Block ID").asInstanceOf[BroadcastBlockId]
- val hType = (json \ "Helper Type").extract[String]
- new BroadcastHelperBlockId(broadcastBlockId, hType)
- case `taskResultBlockId` =>
- val taskId = (json \ "Task ID").extract[Long]
- new TaskResultBlockId(taskId)
- case `streamBlockId` =>
- val streamId = (json \ "Stream ID").extract[Int]
- val uniqueId = (json \ "Unique ID").extract[Long]
- new StreamBlockId(streamId, uniqueId)
- case `tempBlockId` =>
- val tempId = UUIDFromJson(json \ "Temp ID")
- new TempBlockId(tempId)
- case `testBlockId` =>
- val testId = (json \ "Test ID").extract[String]
- new TestBlockId(testId)
- }
+ StorageLevel(useDisk, useMemory, useTachyon, deserialized, replication)
}
def blockStatusFromJson(json: JValue): BlockStatus = {
val storageLevel = storageLevelFromJson(json \ "Storage Level")
val memorySize = (json \ "Memory Size").extract[Long]
val diskSize = (json \ "Disk Size").extract[Long]
- BlockStatus(storageLevel, memorySize, diskSize)
+ val tachyonSize = (json \ "Tachyon Size").extract[Long]
+ BlockStatus(storageLevel, memorySize, diskSize, tachyonSize)
}
diff --git a/core/src/main/scala/org/apache/spark/util/MutablePair.scala b/core/src/main/scala/org/apache/spark/util/MutablePair.scala
index 2c1a6f8fd0a44..a6b39247a54ca 100644
--- a/core/src/main/scala/org/apache/spark/util/MutablePair.scala
+++ b/core/src/main/scala/org/apache/spark/util/MutablePair.scala
@@ -24,8 +24,8 @@ package org.apache.spark.util
* @param _1 Element 1 of this MutablePair
* @param _2 Element 2 of this MutablePair
*/
-case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T1,
- @specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T2]
+case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/* , AnyRef */) T1,
+ @specialized(Int, Long, Double, Char, Boolean/* , AnyRef */) T2]
(var _1: T1, var _2: T2)
extends Product2[T1, T2]
{
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 62ee704d580c2..4435b21a7505e 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -26,6 +26,7 @@ import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadPoolExecutor}
import scala.collection.JavaConversions._
import scala.collection.Map
import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.SortedSet
import scala.io.Source
import scala.reflect.ClassTag
@@ -33,16 +34,20 @@ import com.google.common.io.Files
import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
import org.json4s._
+import tachyon.client.{TachyonFile,TachyonFS}
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
+
/**
* Various utility methods used by Spark.
*/
private[spark] object Utils extends Logging {
+ val osName = System.getProperty("os.name")
+
/** Serialize an object using Java serialization */
def serialize[T](o: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
@@ -150,6 +155,7 @@ private[spark] object Utils extends Logging {
}
private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]()
+ private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]()
// Register the path to be deleted via shutdown hook
def registerShutdownDeleteDir(file: File) {
@@ -159,6 +165,14 @@ private[spark] object Utils extends Logging {
}
}
+ // Register the tachyon path to be deleted via shutdown hook
+ def registerShutdownDeleteDir(tachyonfile: TachyonFile) {
+ val absolutePath = tachyonfile.getPath()
+ shutdownDeleteTachyonPaths.synchronized {
+ shutdownDeleteTachyonPaths += absolutePath
+ }
+ }
+
// Is the path already registered to be deleted via a shutdown hook ?
def hasShutdownDeleteDir(file: File): Boolean = {
val absolutePath = file.getAbsolutePath()
@@ -167,6 +181,14 @@ private[spark] object Utils extends Logging {
}
}
+ // Is the path already registered to be deleted via a shutdown hook ?
+ def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = {
+ val absolutePath = file.getPath()
+ shutdownDeletePaths.synchronized {
+ shutdownDeletePaths.contains(absolutePath)
+ }
+ }
+
// Note: if file is child of some registered path, while not equal to it, then return true;
// else false. This is to ensure that two shutdown hooks do not try to delete each others
// paths - resulting in IOException and incomplete cleanup.
@@ -183,6 +205,22 @@ private[spark] object Utils extends Logging {
retval
}
+ // Note: if file is child of some registered path, while not equal to it, then return true;
+ // else false. This is to ensure that two shutdown hooks do not try to delete each others
+ // paths - resulting in Exception and incomplete cleanup.
+ def hasRootAsShutdownDeleteDir(file: TachyonFile): Boolean = {
+ val absolutePath = file.getPath()
+ val retval = shutdownDeletePaths.synchronized {
+ shutdownDeletePaths.find { path =>
+ !absolutePath.equals(path) && absolutePath.startsWith(path)
+ }.isDefined
+ }
+ if (retval) {
+ logInfo("path = " + file + ", already present as root for deletion.")
+ }
+ retval
+ }
+
/** Create a temporary directory inside the given parent directory */
def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = {
var attempts = 0
@@ -521,9 +559,10 @@ private[spark] object Utils extends Logging {
/**
* Delete a file or directory and its contents recursively.
+ * Don't follow directories if they are symlinks.
*/
def deleteRecursively(file: File) {
- if (file.isDirectory) {
+ if ((file.isDirectory) && !isSymlink(file)) {
for (child <- listFilesSafely(file)) {
deleteRecursively(child)
}
@@ -536,6 +575,49 @@ private[spark] object Utils extends Logging {
}
}
+ /**
+ * Delete a file or directory and its contents recursively.
+ */
+ def deleteRecursively(dir: TachyonFile, client: TachyonFS) {
+ if (!client.delete(dir.getPath(), true)) {
+ throw new IOException("Failed to delete the tachyon dir: " + dir)
+ }
+ }
+
+ /**
+ * Check to see if file is a symbolic link.
+ */
+ def isSymlink(file: File): Boolean = {
+ if (file == null) throw new NullPointerException("File must not be null")
+ if (osName.startsWith("Windows")) return false
+ val fileInCanonicalDir = if (file.getParent() == null) {
+ file
+ } else {
+ new File(file.getParentFile().getCanonicalFile(), file.getName())
+ }
+
+ if (fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile())) {
+ return false
+ } else {
+ return true
+ }
+ }
+
+ /**
+ * Finds all the files in a directory whose last modified time is older than cutoff seconds.
+ * @param dir must be the path to a directory, or IllegalArgumentException is thrown
+ * @param cutoff measured in seconds. Files older than this are returned.
+ */
+ def findOldFiles(dir: File, cutoff: Long): Seq[File] = {
+ val currentTimeMillis = System.currentTimeMillis
+ if (dir.isDirectory) {
+ val files = listFilesSafely(dir)
+ files.filter { file => file.lastModified < (currentTimeMillis - cutoff * 1000) }
+ } else {
+ throw new IllegalArgumentException(dir + " is not a directory!")
+ }
+ }
+
/**
* Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes.
*/
@@ -898,6 +980,26 @@ private[spark] object Utils extends Logging {
count
}
+ /**
+ * Creates a symlink. Note jdk1.7 has Files.createSymbolicLink but not used here
+ * for jdk1.6 support. Supports windows by doing copy, everything else uses "ln -sf".
+ * @param src absolute path to the source
+ * @param dst relative path for the destination
+ */
+ def symlink(src: File, dst: File) {
+ if (!src.isAbsolute()) {
+ throw new IOException("Source must be absolute")
+ }
+ if (dst.isAbsolute()) {
+ throw new IOException("Destination must be relative")
+ }
+ val linkCmd = if (osName.startsWith("Windows")) "copy" else "ln -sf"
+ import scala.sys.process._
+ (linkCmd + " " + src.getAbsolutePath() + " " + dst.getPath()) lines_! ProcessLogger(line =>
+ (logInfo(line)))
+ }
+
+
/** Return the class name of the given object, removing all dollar signs */
def getFormattedClassName(obj: AnyRef) = {
obj.getClass.getSimpleName.replace("$", "")
diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
index d3153d2cac4a5..af1f64649f354 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
@@ -88,6 +88,45 @@ class BitSet(numBits: Int) extends Serializable {
newBS
}
+ /**
+ * Compute the symmetric difference by performing bit-wise XOR of the two sets returning the
+ * result.
+ */
+ def ^(other: BitSet): BitSet = {
+ val newBS = new BitSet(math.max(capacity, other.capacity))
+ val smaller = math.min(numWords, other.numWords)
+ var ind = 0
+ while (ind < smaller) {
+ newBS.words(ind) = words(ind) ^ other.words(ind)
+ ind += 1
+ }
+ if (ind < numWords) {
+ Array.copy( words, ind, newBS.words, ind, numWords - ind )
+ }
+ if (ind < other.numWords) {
+ Array.copy( other.words, ind, newBS.words, ind, other.numWords - ind )
+ }
+ newBS
+ }
+
+ /**
+ * Compute the difference of the two sets by performing bit-wise AND-NOT returning the
+ * result.
+ */
+ def andNot(other: BitSet): BitSet = {
+ val newBS = new BitSet(capacity)
+ val smaller = math.min(numWords, other.numWords)
+ var ind = 0
+ while (ind < smaller) {
+ newBS.words(ind) = words(ind) & ~other.words(ind)
+ ind += 1
+ }
+ if (ind < numWords) {
+ Array.copy( words, ind, newBS.words, ind, numWords - ind )
+ }
+ newBS
+ }
+
/**
* Sets the bit at the specified index to true.
* @param index the bit index
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index c6b65c7348ae0..762405be2a8f9 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -17,13 +17,12 @@
package org.apache.spark;
-import java.io.File;
-import java.io.IOException;
-import java.io.Serializable;
+import java.io.*;
import java.util.*;
import scala.Tuple2;
+import com.google.common.collect.Lists;
import com.google.common.base.Optional;
import com.google.common.base.Charsets;
import com.google.common.io.Files;
@@ -181,6 +180,14 @@ public void call(String s) {
Assert.assertEquals(2, foreachCalls);
}
+ @Test
+ public void toLocalIterator() {
+ List correct = Arrays.asList(1, 2, 3, 4);
+ JavaRDD rdd = sc.parallelize(correct);
+ List result = Lists.newArrayList(rdd.toLocalIterator());
+ Assert.assertTrue(correct.equals(result));
+ }
+
@SuppressWarnings("unchecked")
@Test
public void lookup() {
@@ -599,6 +606,32 @@ public void textFiles() throws IOException {
Assert.assertEquals(expected, readRDD.collect());
}
+ @Test
+ public void wholeTextFiles() throws IOException {
+ byte[] content1 = "spark is easy to use.\n".getBytes();
+ byte[] content2 = "spark is also easy to use.\n".getBytes();
+
+ File tempDir = Files.createTempDir();
+ String tempDirName = tempDir.getAbsolutePath();
+ DataOutputStream ds = new DataOutputStream(new FileOutputStream(tempDirName + "/part-00000"));
+ ds.write(content1);
+ ds.close();
+ ds = new DataOutputStream(new FileOutputStream(tempDirName + "/part-00001"));
+ ds.write(content2);
+ ds.close();
+
+ HashMap container = new HashMap();
+ container.put(tempDirName+"/part-00000", new Text(content1).toString());
+ container.put(tempDirName+"/part-00001", new Text(content2).toString());
+
+ JavaPairRDD readRDD = sc.wholeTextFiles(tempDirName);
+ List> result = readRDD.collect();
+
+ for (Tuple2 res : result) {
+ Assert.assertEquals(res._2(), container.get(res._1()));
+ }
+ }
+
@Test
public void textFilesCompressed() throws IOException {
File tempDir = Files.createTempDir();
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index 6c73ea6949dd2..4e7c34e6d1ada 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -66,7 +66,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkConte
test ("add value to collection accumulators") {
val maxI = 1000
- for (nThreads <- List(1, 10)) { //test single & multi-threaded
+ for (nThreads <- List(1, 10)) { // test single & multi-threaded
sc = new SparkContext("local[" + nThreads + "]", "test")
val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
val d = sc.parallelize(1 to maxI)
@@ -83,7 +83,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkConte
test ("value not readable in tasks") {
val maxI = 1000
- for (nThreads <- List(1, 10)) { //test single & multi-threaded
+ for (nThreads <- List(1, 10)) { // test single & multi-threaded
sc = new SparkContext("local[" + nThreads + "]", "test")
val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
val d = sc.parallelize(1 to maxI)
@@ -124,7 +124,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkConte
test ("localValue readable in tasks") {
val maxI = 1000
- for (nThreads <- List(1, 10)) { //test single & multi-threaded
+ for (nThreads <- List(1, 10)) { // test single & multi-threaded
sc = new SparkContext("local[" + nThreads + "]", "test")
val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
val groupedInts = (1 to (maxI/20)).map {x => (20 * (x - 1) to 20 * x).toSet}
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index d2e29f20f0b08..d2555b7c052c1 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -432,7 +432,6 @@ object CheckpointSuite {
// This is a custom cogroup function that does not use mapValues like
// the PairRDDFunctions.cogroup()
def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = {
- //println("First = " + first + ", second = " + second)
new CoGroupedRDD[K](
Seq(first.asInstanceOf[RDD[(K, _)]], second.asInstanceOf[RDD[(K, _)]]),
part
diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
index 996db70809320..7c30626a0c421 100644
--- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -146,7 +146,7 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet
assert(intercept[SparkException]{ arrs.distinct() }.getMessage.contains("array"))
// We can't catch all usages of arrays, since they might occur inside other collections:
- //assert(fails { arrPairs.distinct() })
+ // assert(fails { arrPairs.distinct() })
assert(intercept[SparkException]{ arrPairs.partitionBy(new HashPartitioner(2)) }.getMessage.contains("array"))
assert(intercept[SparkException]{ arrPairs.join(arrPairs) }.getMessage.contains("array"))
assert(intercept[SparkException]{ arrPairs.leftOuterJoin(arrPairs) }.getMessage.contains("array"))
diff --git a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
index 6e7fd55fa4bb1..627e9b5cd9060 100644
--- a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
@@ -17,8 +17,11 @@
package org.apache.spark
-import org.scalatest.FunSuite
+import java.io.File
+
+import com.google.common.io.Files
+import org.scalatest.FunSuite
import org.apache.spark.rdd.{HadoopRDD, PipedRDD, HadoopPartition}
import org.apache.hadoop.mapred.{JobConf, TextInputFormat, FileSplit}
@@ -126,6 +129,29 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext {
}
}
+ test("basic pipe with separate working directory") {
+ if (testCommandAvailable("cat")) {
+ val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ val piped = nums.pipe(Seq("cat"), separateWorkingDir = true)
+ val c = piped.collect()
+ assert(c.size === 4)
+ assert(c(0) === "1")
+ assert(c(1) === "2")
+ assert(c(2) === "3")
+ assert(c(3) === "4")
+ val pipedPwd = nums.pipe(Seq("pwd"), separateWorkingDir = true)
+ val collectPwd = pipedPwd.collect()
+ assert(collectPwd(0).contains("tasks/"))
+ val pipedLs = nums.pipe(Seq("ls"), separateWorkingDir = true).collect()
+ // make sure symlinks were created
+ assert(pipedLs.length > 0)
+ // clean up top level tasks directory
+ new File("tasks").delete()
+ } else {
+ assert(true)
+ }
+ }
+
test("test pipe exports map_input_file") {
testExportInputFile("map_input_file")
}
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
index b543471a5d35b..94fba102865b3 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
@@ -51,6 +51,14 @@ class SparkContextSchedulerCreationSuite
}
}
+ test("local-*") {
+ val sched = createTaskScheduler("local[*]")
+ sched.backend match {
+ case s: LocalBackend => assert(s.totalCores === Runtime.getRuntime.availableProcessors())
+ case _ => fail()
+ }
+ }
+
test("local-n") {
val sched = createTaskScheduler("local[5]")
assert(sched.maxTaskFailures === 1)
diff --git a/core/src/test/scala/org/apache/spark/TestUtils.scala b/core/src/test/scala/org/apache/spark/TestUtils.scala
new file mode 100644
index 0000000000000..1611d09652d40
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/TestUtils.scala
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io.{File, FileInputStream, FileOutputStream}
+import java.net.{URI, URL}
+import java.util.jar.{JarEntry, JarOutputStream}
+
+import scala.collection.JavaConversions._
+
+import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}
+import com.google.common.io.Files
+
+object TestUtils {
+
+ /**
+ * Create a jar that defines classes with the given names.
+ *
+ * Note: if this is used during class loader tests, class names should be unique
+ * in order to avoid interference between tests.
+ */
+ def createJarWithClasses(classNames: Seq[String]): URL = {
+ val tempDir = Files.createTempDir()
+ val files = for (name <- classNames) yield createCompiledClass(name, tempDir)
+ val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis()))
+ createJar(files, jarFile)
+ }
+
+ /**
+ * Create a jar file that contains this set of files. All files will be located at the root
+ * of the jar.
+ */
+ def createJar(files: Seq[File], jarFile: File): URL = {
+ val jarFileStream = new FileOutputStream(jarFile)
+ val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest())
+
+ for (file <- files) {
+ val jarEntry = new JarEntry(file.getName)
+ jarStream.putNextEntry(jarEntry)
+
+ val in = new FileInputStream(file)
+ val buffer = new Array[Byte](10240)
+ var nRead = 0
+ while (nRead <= 0) {
+ nRead = in.read(buffer, 0, buffer.length)
+ jarStream.write(buffer, 0, nRead)
+ }
+ in.close()
+ }
+ jarStream.close()
+ jarFileStream.close()
+
+ jarFile.toURI.toURL
+ }
+
+ // Adapted from the JavaCompiler.java doc examples
+ private val SOURCE = JavaFileObject.Kind.SOURCE
+ private def createURI(name: String) = {
+ URI.create(s"string:///${name.replace(".", "/")}${SOURCE.extension}")
+ }
+
+ private class JavaSourceFromString(val name: String, val code: String)
+ extends SimpleJavaFileObject(createURI(name), SOURCE) {
+ override def getCharContent(ignoreEncodingErrors: Boolean) = code
+ }
+
+ /** Creates a compiled class with the given name. Class file will be placed in destDir. */
+ def createCompiledClass(className: String, destDir: File): File = {
+ val compiler = ToolProvider.getSystemJavaCompiler
+ val sourceFile = new JavaSourceFromString(className, s"public class $className {}")
+
+ // Calling this outputs a class file in pwd. It's easier to just rename the file than
+ // build a custom FileManager that controls the output location.
+ compiler.getTask(null, null, null, null, null, Seq(sourceFile)).call()
+
+ val fileName = className + ".class"
+ val result = new File(fileName)
+ if (!result.exists()) throw new Exception("Compiled file not found: " + fileName)
+ val out = new File(destDir, fileName)
+ result.renameTo(out)
+ out
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
new file mode 100644
index 0000000000000..4e489cd9b66a6
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+import java.io.{OutputStream, PrintStream}
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+
+import org.apache.spark.deploy.SparkSubmit._
+
+
+class SparkSubmitSuite extends FunSuite with ShouldMatchers {
+
+ val noOpOutputStream = new OutputStream {
+ def write(b: Int) = {}
+ }
+
+ /** Simple PrintStream that reads data into a buffer */
+ class BufferPrintStream extends PrintStream(noOpOutputStream) {
+ var lineBuffer = ArrayBuffer[String]()
+ override def println(line: String) {
+ lineBuffer += line
+ }
+ }
+
+ /** Returns true if the script exits and the given search string is printed. */
+ def testPrematureExit(input: Array[String], searchString: String): Boolean = {
+ val printStream = new BufferPrintStream()
+ SparkSubmit.printStream = printStream
+
+ @volatile var exitedCleanly = false
+ SparkSubmit.exitFn = () => exitedCleanly = true
+
+ val thread = new Thread {
+ override def run() = try {
+ SparkSubmit.main(input)
+ } catch {
+ // If exceptions occur after the "exit" has happened, fine to ignore them.
+ // These represent code paths not reachable during normal execution.
+ case e: Exception => if (!exitedCleanly) throw e
+ }
+ }
+ thread.start()
+ thread.join()
+ printStream.lineBuffer.find(s => s.contains(searchString)).size > 0
+ }
+
+ test("prints usage on empty input") {
+ testPrematureExit(Array[String](), "Usage: spark-submit") should be (true)
+ }
+
+ test("prints usage with only --help") {
+ testPrematureExit(Array("--help"), "Usage: spark-submit") should be (true)
+ }
+
+ test("handles multiple binary definitions") {
+ val adjacentJars = Array("foo.jar", "bar.jar")
+ testPrematureExit(adjacentJars, "error: Found two conflicting resources") should be (true)
+
+ val nonAdjacentJars =
+ Array("foo.jar", "--master", "123", "--class", "abc", "bar.jar")
+ testPrematureExit(nonAdjacentJars, "error: Found two conflicting resources") should be (true)
+ }
+
+ test("handle binary specified but not class") {
+ testPrematureExit(Array("foo.jar"), "must specify a main class")
+ }
+
+ test("handles YARN cluster mode") {
+ val clArgs = Array("thejar.jar", "--deploy-mode", "cluster",
+ "--master", "yarn", "--executor-memory", "5g", "--executor-cores", "5",
+ "--class", "org.SomeClass", "--jars", "one.jar,two.jar,three.jar",
+ "--arg", "arg1", "--arg", "arg2", "--driver-memory", "4g",
+ "--queue", "thequeue", "--files", "file1.txt,file2.txt",
+ "--archives", "archive1.txt,archive2.txt", "--num-executors", "6")
+ val appArgs = new SparkSubmitArguments(clArgs)
+ val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs)
+ val childArgsStr = childArgs.mkString(" ")
+ childArgsStr should include ("--jar thejar.jar")
+ childArgsStr should include ("--class org.SomeClass")
+ childArgsStr should include ("--addJars one.jar,two.jar,three.jar")
+ childArgsStr should include ("--executor-memory 5g")
+ childArgsStr should include ("--driver-memory 4g")
+ childArgsStr should include ("--executor-cores 5")
+ childArgsStr should include ("--args arg1 --args arg2")
+ childArgsStr should include ("--queue thequeue")
+ childArgsStr should include ("--files file1.txt,file2.txt")
+ childArgsStr should include ("--archives archive1.txt,archive2.txt")
+ childArgsStr should include ("--num-executors 6")
+ mainClass should be ("org.apache.spark.deploy.yarn.Client")
+ classpath should have length (0)
+ sysProps should have size (0)
+ }
+
+ test("handles YARN client mode") {
+ val clArgs = Array("thejar.jar", "--deploy-mode", "client",
+ "--master", "yarn", "--executor-memory", "5g", "--executor-cores", "5",
+ "--class", "org.SomeClass", "--jars", "one.jar,two.jar,three.jar",
+ "--arg", "arg1", "--arg", "arg2", "--driver-memory", "4g",
+ "--queue", "thequeue", "--files", "file1.txt,file2.txt",
+ "--archives", "archive1.txt,archive2.txt", "--num-executors", "6")
+ val appArgs = new SparkSubmitArguments(clArgs)
+ val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs)
+ childArgs.mkString(" ") should be ("arg1 arg2")
+ mainClass should be ("org.SomeClass")
+ classpath should contain ("thejar.jar")
+ classpath should contain ("one.jar")
+ classpath should contain ("two.jar")
+ classpath should contain ("three.jar")
+ sysProps("spark.executor.memory") should be ("5g")
+ sysProps("spark.executor.cores") should be ("5")
+ sysProps("spark.yarn.queue") should be ("thequeue")
+ sysProps("spark.yarn.dist.files") should be ("file1.txt,file2.txt")
+ sysProps("spark.yarn.dist.archives") should be ("archive1.txt,archive2.txt")
+ sysProps("spark.executor.instances") should be ("6")
+ }
+
+ test("handles standalone cluster mode") {
+ val clArgs = Array("thejar.jar", "--deploy-mode", "cluster",
+ "--master", "spark://h:p", "--class", "org.SomeClass", "--arg", "arg1", "--arg", "arg2",
+ "--supervise", "--driver-memory", "4g", "--driver-cores", "5")
+ val appArgs = new SparkSubmitArguments(clArgs)
+ val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs)
+ val childArgsStr = childArgs.mkString(" ")
+ print("child args: " + childArgsStr)
+ childArgsStr.startsWith("--memory 4g --cores 5 --supervise") should be (true)
+ childArgsStr should include ("launch spark://h:p thejar.jar org.SomeClass arg1 arg2")
+ mainClass should be ("org.apache.spark.deploy.Client")
+ classpath should have length (0)
+ sysProps should have size (0)
+ }
+
+ test("handles standalone client mode") {
+ val clArgs = Array("thejar.jar", "--deploy-mode", "client",
+ "--master", "spark://h:p", "--executor-memory", "5g", "--total-executor-cores", "5",
+ "--class", "org.SomeClass", "--arg", "arg1", "--arg", "arg2",
+ "--driver-memory", "4g")
+ val appArgs = new SparkSubmitArguments(clArgs)
+ val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs)
+ childArgs.mkString(" ") should be ("arg1 arg2")
+ mainClass should be ("org.SomeClass")
+ classpath should contain ("thejar.jar")
+ sysProps("spark.executor.memory") should be ("5g")
+ sysProps("spark.cores.max") should be ("5")
+ }
+
+ test("handles mesos client mode") {
+ val clArgs = Array("thejar.jar", "--deploy-mode", "client",
+ "--master", "mesos://h:p", "--executor-memory", "5g", "--total-executor-cores", "5",
+ "--class", "org.SomeClass", "--arg", "arg1", "--arg", "arg2",
+ "--driver-memory", "4g")
+ val appArgs = new SparkSubmitArguments(clArgs)
+ val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs)
+ childArgs.mkString(" ") should be ("arg1 arg2")
+ mainClass should be ("org.SomeClass")
+ classpath should contain ("thejar.jar")
+ sysProps("spark.executor.memory") should be ("5g")
+ sysProps("spark.cores.max") should be ("5")
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala
new file mode 100644
index 0000000000000..09e35bfc8f85f
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.input
+
+import java.io.DataOutputStream
+import java.io.File
+import java.io.FileOutputStream
+
+import scala.collection.immutable.IndexedSeq
+
+import com.google.common.io.Files
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.FunSuite
+
+import org.apache.hadoop.io.Text
+
+import org.apache.spark.SparkContext
+
+/**
+ * Tests the correctness of
+ * [[org.apache.spark.input.WholeTextFileRecordReader WholeTextFileRecordReader]]. A temporary
+ * directory is created as fake input. Temporal storage would be deleted in the end.
+ */
+class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll {
+ private var sc: SparkContext = _
+
+ override def beforeAll() {
+ sc = new SparkContext("local", "test")
+
+ // Set the block size of local file system to test whether files are split right or not.
+ sc.hadoopConfiguration.setLong("fs.local.block.size", 32)
+ }
+
+ override def afterAll() {
+ sc.stop()
+ }
+
+ private def createNativeFile(inputDir: File, fileName: String, contents: Array[Byte]) = {
+ val out = new DataOutputStream(new FileOutputStream(s"${inputDir.toString}/$fileName"))
+ out.write(contents, 0, contents.length)
+ out.close()
+ }
+
+ /**
+ * This code will test the behaviors of WholeTextFileRecordReader based on local disk. There are
+ * three aspects to check:
+ * 1) Whether all files are read;
+ * 2) Whether paths are read correctly;
+ * 3) Does the contents be the same.
+ */
+ test("Correctness of WholeTextFileRecordReader.") {
+
+ val dir = Files.createTempDir()
+ println(s"Local disk address is ${dir.toString}.")
+
+ WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) =>
+ createNativeFile(dir, filename, contents)
+ }
+
+ val res = sc.wholeTextFiles(dir.toString).collect()
+
+ assert(res.size === WholeTextFileRecordReaderSuite.fileNames.size,
+ "Number of files read out does not fit with the actual value.")
+
+ for ((filename, contents) <- res) {
+ val shortName = filename.split('/').last
+ assert(WholeTextFileRecordReaderSuite.fileNames.contains(shortName),
+ s"Missing file name $filename.")
+ assert(contents === new Text(WholeTextFileRecordReaderSuite.files(shortName)).toString,
+ s"file $filename contents can not match.")
+ }
+
+ dir.delete()
+ }
+}
+
+/**
+ * Files to be tested are defined here.
+ */
+object WholeTextFileRecordReaderSuite {
+ private val testWords: IndexedSeq[Byte] = "Spark is easy to use.\n".map(_.toByte)
+
+ private val fileNames = Array("part-00000", "part-00001", "part-00002")
+ private val fileLengths = Array(10, 100, 1000)
+
+ private val files = fileLengths.zip(fileNames).map { case (upperBound, filename) =>
+ filename -> Stream.continually(testWords.toList.toStream).flatten.take(upperBound).toArray
+ }.toMap
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index d6b5fdc7984b4..25973348a7837 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -33,6 +33,7 @@ class RDDSuite extends FunSuite with SharedSparkContext {
test("basic operations") {
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
assert(nums.collect().toList === List(1, 2, 3, 4))
+ assert(nums.toLocalIterator.toList === List(1, 2, 3, 4))
val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)
assert(dups.distinct().count() === 4)
assert(dups.distinct.count === 4) // Can distinct and count be called without parentheses?
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index c97543f57d8f3..ce567b0cde85d 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -428,7 +428,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
assert(scheduler.pendingTasks.isEmpty)
assert(scheduler.activeJobs.isEmpty)
assert(scheduler.failedStages.isEmpty)
- assert(scheduler.stageIdToActiveJob.isEmpty)
+ assert(scheduler.jobIdToActiveJob.isEmpty)
assert(scheduler.jobIdToStageIds.isEmpty)
assert(scheduler.stageIdToJobIds.isEmpty)
assert(scheduler.stageIdToStage.isEmpty)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index a25ce35736146..7c843772bc2e0 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -111,7 +111,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
val listener = new SaveStageAndTaskInfo
sc.addSparkListener(listener)
sc.addSparkListener(new StatsReportListener)
- //just to make sure some of the tasks take a noticeable amount of time
+ // just to make sure some of the tasks take a noticeable amount of time
val w = {i:Int =>
if (i == 0)
Thread.sleep(100)
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index e83cd55e73691..b6dd0526105a0 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -96,9 +96,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("StorageLevel object caching") {
- val level1 = StorageLevel(false, false, false, 3)
- val level2 = StorageLevel(false, false, false, 3) // this should return the same object as level1
- val level3 = StorageLevel(false, false, false, 2) // this should return a different object
+ val level1 = StorageLevel(false, false, false, false, 3)
+ val level2 = StorageLevel(false, false, false, false, 3) // this should return the same object as level1
+ val level3 = StorageLevel(false, false, false, false, 2) // this should return a different object
assert(level2 === level1, "level2 is not same as level1")
assert(level2.eq(level1), "level2 is not the same object as level1")
assert(level3 != level1, "level3 is same as level1")
@@ -410,6 +410,25 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
assert(store.memoryStore.contains(rdd(0, 3)), "rdd_0_3 was not in store")
}
+ test("tachyon storage") {
+ // TODO Make the spark.test.tachyon.enable true after using tachyon 0.5.0 testing jar.
+ val tachyonUnitTestEnabled = conf.getBoolean("spark.test.tachyon.enable", false)
+ if (tachyonUnitTestEnabled) {
+ store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr)
+ val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
+ val a3 = new Array[Byte](400)
+ store.putSingle("a1", a1, StorageLevel.OFF_HEAP)
+ store.putSingle("a2", a2, StorageLevel.OFF_HEAP)
+ store.putSingle("a3", a3, StorageLevel.OFF_HEAP)
+ assert(store.getSingle("a3").isDefined, "a3 was in store")
+ assert(store.getSingle("a2").isDefined, "a2 was in store")
+ assert(store.getSingle("a1").isDefined, "a1 was in store")
+ } else {
+ info("tachyon storage test disabled.")
+ }
+ }
+
test("on-disk storage") {
store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index d8a3e859f85cd..beac656f573b4 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -18,16 +18,45 @@
package org.apache.spark.ui.jobs
import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
-import org.apache.spark.{LocalSparkContext, SparkContext, Success}
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, Success}
import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics}
import org.apache.spark.scheduler._
import org.apache.spark.util.Utils
-class JobProgressListenerSuite extends FunSuite with LocalSparkContext {
+class JobProgressListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+ test("test LRU eviction of stages") {
+ val conf = new SparkConf()
+ conf.set("spark.ui.retainedStages", 5.toString)
+ val listener = new JobProgressListener(conf)
+
+ def createStageStartEvent(stageId: Int) = {
+ val stageInfo = new StageInfo(stageId, stageId.toString, 0, null)
+ SparkListenerStageSubmitted(stageInfo)
+ }
+
+ def createStageEndEvent(stageId: Int) = {
+ val stageInfo = new StageInfo(stageId, stageId.toString, 0, null)
+ SparkListenerStageCompleted(stageInfo)
+ }
+
+ for (i <- 1 to 50) {
+ listener.onStageSubmitted(createStageStartEvent(i))
+ listener.onStageCompleted(createStageEndEvent(i))
+ }
+
+ listener.completedStages.size should be (5)
+ listener.completedStages.filter(_.stageId == 50).size should be (1)
+ listener.completedStages.filter(_.stageId == 49).size should be (1)
+ listener.completedStages.filter(_.stageId == 48).size should be (1)
+ listener.completedStages.filter(_.stageId == 47).size should be (1)
+ listener.completedStages.filter(_.stageId == 46).size should be (1)
+ }
+
test("test executor id to summary") {
- val sc = new SparkContext("local", "test")
- val listener = new JobProgressListener(sc.conf)
+ val conf = new SparkConf()
+ val listener = new JobProgressListener(conf)
val taskMetrics = new TaskMetrics()
val shuffleReadMetrics = new ShuffleReadMetrics()
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index 67c0a434c9b52..054eb01a64c11 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -112,7 +112,6 @@ class JsonProtocolSuite extends FunSuite {
testBlockId(BroadcastHelperBlockId(BroadcastBlockId(2L), "Spark"))
testBlockId(TaskResultBlockId(1L))
testBlockId(StreamBlockId(1, 2L))
- testBlockId(TempBlockId(UUID.randomUUID()))
}
@@ -168,8 +167,8 @@ class JsonProtocolSuite extends FunSuite {
}
private def testBlockId(blockId: BlockId) {
- val newBlockId = JsonProtocol.blockIdFromJson(JsonProtocol.blockIdToJson(blockId))
- blockId == newBlockId
+ val newBlockId = BlockId(blockId.toString)
+ assert(blockId === newBlockId)
}
@@ -180,90 +179,90 @@ class JsonProtocolSuite extends FunSuite {
private def assertEquals(event1: SparkListenerEvent, event2: SparkListenerEvent) {
(event1, event2) match {
case (e1: SparkListenerStageSubmitted, e2: SparkListenerStageSubmitted) =>
- assert(e1.properties == e2.properties)
+ assert(e1.properties === e2.properties)
assertEquals(e1.stageInfo, e2.stageInfo)
case (e1: SparkListenerStageCompleted, e2: SparkListenerStageCompleted) =>
assertEquals(e1.stageInfo, e2.stageInfo)
case (e1: SparkListenerTaskStart, e2: SparkListenerTaskStart) =>
- assert(e1.stageId == e2.stageId)
+ assert(e1.stageId === e2.stageId)
assertEquals(e1.taskInfo, e2.taskInfo)
case (e1: SparkListenerTaskGettingResult, e2: SparkListenerTaskGettingResult) =>
assertEquals(e1.taskInfo, e2.taskInfo)
case (e1: SparkListenerTaskEnd, e2: SparkListenerTaskEnd) =>
- assert(e1.stageId == e2.stageId)
- assert(e1.taskType == e2.taskType)
+ assert(e1.stageId === e2.stageId)
+ assert(e1.taskType === e2.taskType)
assertEquals(e1.reason, e2.reason)
assertEquals(e1.taskInfo, e2.taskInfo)
assertEquals(e1.taskMetrics, e2.taskMetrics)
case (e1: SparkListenerJobStart, e2: SparkListenerJobStart) =>
- assert(e1.jobId == e2.jobId)
- assert(e1.properties == e2.properties)
- assertSeqEquals(e1.stageIds, e2.stageIds, (i1: Int, i2: Int) => assert(i1 == i2))
+ assert(e1.jobId === e2.jobId)
+ assert(e1.properties === e2.properties)
+ assertSeqEquals(e1.stageIds, e2.stageIds, (i1: Int, i2: Int) => assert(i1 === i2))
case (e1: SparkListenerJobEnd, e2: SparkListenerJobEnd) =>
- assert(e1.jobId == e2.jobId)
+ assert(e1.jobId === e2.jobId)
assertEquals(e1.jobResult, e2.jobResult)
case (e1: SparkListenerEnvironmentUpdate, e2: SparkListenerEnvironmentUpdate) =>
assertEquals(e1.environmentDetails, e2.environmentDetails)
case (e1: SparkListenerBlockManagerAdded, e2: SparkListenerBlockManagerAdded) =>
- assert(e1.maxMem == e2.maxMem)
+ assert(e1.maxMem === e2.maxMem)
assertEquals(e1.blockManagerId, e2.blockManagerId)
case (e1: SparkListenerBlockManagerRemoved, e2: SparkListenerBlockManagerRemoved) =>
assertEquals(e1.blockManagerId, e2.blockManagerId)
case (e1: SparkListenerUnpersistRDD, e2: SparkListenerUnpersistRDD) =>
- assert(e1.rddId == e2.rddId)
+ assert(e1.rddId === e2.rddId)
case (SparkListenerShutdown, SparkListenerShutdown) =>
case _ => fail("Events don't match in types!")
}
}
private def assertEquals(info1: StageInfo, info2: StageInfo) {
- assert(info1.stageId == info2.stageId)
- assert(info1.name == info2.name)
- assert(info1.numTasks == info2.numTasks)
- assert(info1.submissionTime == info2.submissionTime)
- assert(info1.completionTime == info2.completionTime)
- assert(info1.emittedTaskSizeWarning == info2.emittedTaskSizeWarning)
+ assert(info1.stageId === info2.stageId)
+ assert(info1.name === info2.name)
+ assert(info1.numTasks === info2.numTasks)
+ assert(info1.submissionTime === info2.submissionTime)
+ assert(info1.completionTime === info2.completionTime)
+ assert(info1.emittedTaskSizeWarning === info2.emittedTaskSizeWarning)
assertEquals(info1.rddInfo, info2.rddInfo)
}
private def assertEquals(info1: RDDInfo, info2: RDDInfo) {
- assert(info1.id == info2.id)
- assert(info1.name == info2.name)
- assert(info1.numPartitions == info2.numPartitions)
- assert(info1.numCachedPartitions == info2.numCachedPartitions)
- assert(info1.memSize == info2.memSize)
- assert(info1.diskSize == info2.diskSize)
+ assert(info1.id === info2.id)
+ assert(info1.name === info2.name)
+ assert(info1.numPartitions === info2.numPartitions)
+ assert(info1.numCachedPartitions === info2.numCachedPartitions)
+ assert(info1.memSize === info2.memSize)
+ assert(info1.diskSize === info2.diskSize)
assertEquals(info1.storageLevel, info2.storageLevel)
}
private def assertEquals(level1: StorageLevel, level2: StorageLevel) {
- assert(level1.useDisk == level2.useDisk)
- assert(level1.useMemory == level2.useMemory)
- assert(level1.deserialized == level2.deserialized)
- assert(level1.replication == level2.replication)
+ assert(level1.useDisk === level2.useDisk)
+ assert(level1.useMemory === level2.useMemory)
+ assert(level1.deserialized === level2.deserialized)
+ assert(level1.replication === level2.replication)
}
private def assertEquals(info1: TaskInfo, info2: TaskInfo) {
- assert(info1.taskId == info2.taskId)
- assert(info1.index == info2.index)
- assert(info1.launchTime == info2.launchTime)
- assert(info1.executorId == info2.executorId)
- assert(info1.host == info2.host)
- assert(info1.taskLocality == info2.taskLocality)
- assert(info1.gettingResultTime == info2.gettingResultTime)
- assert(info1.finishTime == info2.finishTime)
- assert(info1.failed == info2.failed)
- assert(info1.serializedSize == info2.serializedSize)
+ assert(info1.taskId === info2.taskId)
+ assert(info1.index === info2.index)
+ assert(info1.launchTime === info2.launchTime)
+ assert(info1.executorId === info2.executorId)
+ assert(info1.host === info2.host)
+ assert(info1.taskLocality === info2.taskLocality)
+ assert(info1.gettingResultTime === info2.gettingResultTime)
+ assert(info1.finishTime === info2.finishTime)
+ assert(info1.failed === info2.failed)
+ assert(info1.serializedSize === info2.serializedSize)
}
private def assertEquals(metrics1: TaskMetrics, metrics2: TaskMetrics) {
- assert(metrics1.hostname == metrics2.hostname)
- assert(metrics1.executorDeserializeTime == metrics2.executorDeserializeTime)
- assert(metrics1.resultSize == metrics2.resultSize)
- assert(metrics1.jvmGCTime == metrics2.jvmGCTime)
- assert(metrics1.resultSerializationTime == metrics2.resultSerializationTime)
- assert(metrics1.memoryBytesSpilled == metrics2.memoryBytesSpilled)
- assert(metrics1.diskBytesSpilled == metrics2.diskBytesSpilled)
+ assert(metrics1.hostname === metrics2.hostname)
+ assert(metrics1.executorDeserializeTime === metrics2.executorDeserializeTime)
+ assert(metrics1.resultSize === metrics2.resultSize)
+ assert(metrics1.jvmGCTime === metrics2.jvmGCTime)
+ assert(metrics1.resultSerializationTime === metrics2.resultSerializationTime)
+ assert(metrics1.memoryBytesSpilled === metrics2.memoryBytesSpilled)
+ assert(metrics1.diskBytesSpilled === metrics2.diskBytesSpilled)
assertOptionEquals(
metrics1.shuffleReadMetrics, metrics2.shuffleReadMetrics, assertShuffleReadEquals)
assertOptionEquals(
@@ -272,31 +271,31 @@ class JsonProtocolSuite extends FunSuite {
}
private def assertEquals(metrics1: ShuffleReadMetrics, metrics2: ShuffleReadMetrics) {
- assert(metrics1.shuffleFinishTime == metrics2.shuffleFinishTime)
- assert(metrics1.totalBlocksFetched == metrics2.totalBlocksFetched)
- assert(metrics1.remoteBlocksFetched == metrics2.remoteBlocksFetched)
- assert(metrics1.localBlocksFetched == metrics2.localBlocksFetched)
- assert(metrics1.fetchWaitTime == metrics2.fetchWaitTime)
- assert(metrics1.remoteBytesRead == metrics2.remoteBytesRead)
+ assert(metrics1.shuffleFinishTime === metrics2.shuffleFinishTime)
+ assert(metrics1.totalBlocksFetched === metrics2.totalBlocksFetched)
+ assert(metrics1.remoteBlocksFetched === metrics2.remoteBlocksFetched)
+ assert(metrics1.localBlocksFetched === metrics2.localBlocksFetched)
+ assert(metrics1.fetchWaitTime === metrics2.fetchWaitTime)
+ assert(metrics1.remoteBytesRead === metrics2.remoteBytesRead)
}
private def assertEquals(metrics1: ShuffleWriteMetrics, metrics2: ShuffleWriteMetrics) {
- assert(metrics1.shuffleBytesWritten == metrics2.shuffleBytesWritten)
- assert(metrics1.shuffleWriteTime == metrics2.shuffleWriteTime)
+ assert(metrics1.shuffleBytesWritten === metrics2.shuffleBytesWritten)
+ assert(metrics1.shuffleWriteTime === metrics2.shuffleWriteTime)
}
private def assertEquals(bm1: BlockManagerId, bm2: BlockManagerId) {
- assert(bm1.executorId == bm2.executorId)
- assert(bm1.host == bm2.host)
- assert(bm1.port == bm2.port)
- assert(bm1.nettyPort == bm2.nettyPort)
+ assert(bm1.executorId === bm2.executorId)
+ assert(bm1.host === bm2.host)
+ assert(bm1.port === bm2.port)
+ assert(bm1.nettyPort === bm2.nettyPort)
}
private def assertEquals(result1: JobResult, result2: JobResult) {
(result1, result2) match {
case (JobSucceeded, JobSucceeded) =>
case (r1: JobFailed, r2: JobFailed) =>
- assert(r1.failedStageId == r2.failedStageId)
+ assert(r1.failedStageId === r2.failedStageId)
assertEquals(r1.exception, r2.exception)
case _ => fail("Job results don't match in types!")
}
@@ -307,13 +306,13 @@ class JsonProtocolSuite extends FunSuite {
case (Success, Success) =>
case (Resubmitted, Resubmitted) =>
case (r1: FetchFailed, r2: FetchFailed) =>
- assert(r1.shuffleId == r2.shuffleId)
- assert(r1.mapId == r2.mapId)
- assert(r1.reduceId == r2.reduceId)
+ assert(r1.shuffleId === r2.shuffleId)
+ assert(r1.mapId === r2.mapId)
+ assert(r1.reduceId === r2.reduceId)
assertEquals(r1.bmAddress, r2.bmAddress)
case (r1: ExceptionFailure, r2: ExceptionFailure) =>
- assert(r1.className == r2.className)
- assert(r1.description == r2.description)
+ assert(r1.className === r2.className)
+ assert(r1.description === r2.description)
assertSeqEquals(r1.stackTrace, r2.stackTrace, assertStackTraceElementEquals)
assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals)
case (TaskResultLost, TaskResultLost) =>
@@ -329,13 +328,13 @@ class JsonProtocolSuite extends FunSuite {
details2: Map[String, Seq[(String, String)]]) {
details1.zip(details2).foreach {
case ((key1, values1: Seq[(String, String)]), (key2, values2: Seq[(String, String)])) =>
- assert(key1 == key2)
- values1.zip(values2).foreach { case (v1, v2) => assert(v1 == v2) }
+ assert(key1 === key2)
+ values1.zip(values2).foreach { case (v1, v2) => assert(v1 === v2) }
}
}
private def assertEquals(exception1: Exception, exception2: Exception) {
- assert(exception1.getMessage == exception2.getMessage)
+ assert(exception1.getMessage === exception2.getMessage)
assertSeqEquals(
exception1.getStackTrace,
exception2.getStackTrace,
@@ -344,11 +343,11 @@ class JsonProtocolSuite extends FunSuite {
private def assertJsonStringEquals(json1: String, json2: String) {
val formatJsonString = (json: String) => json.replaceAll("[\\s|]", "")
- formatJsonString(json1) == formatJsonString(json2)
+ formatJsonString(json1) === formatJsonString(json2)
}
private def assertSeqEquals[T](seq1: Seq[T], seq2: Seq[T], assertEquals: (T, T) => Unit) {
- assert(seq1.length == seq2.length)
+ assert(seq1.length === seq2.length)
seq1.zip(seq2).foreach { case (t1, t2) =>
assertEquals(t1, t2)
}
@@ -389,11 +388,11 @@ class JsonProtocolSuite extends FunSuite {
}
private def assertBlockEquals(b1: (BlockId, BlockStatus), b2: (BlockId, BlockStatus)) {
- assert(b1 == b2)
+ assert(b1 === b2)
}
private def assertStackTraceElementEquals(ste1: StackTraceElement, ste2: StackTraceElement) {
- assert(ste1 == ste2)
+ assert(ste1 === ste2)
}
@@ -457,7 +456,7 @@ class JsonProtocolSuite extends FunSuite {
t.shuffleWriteMetrics = Some(sw)
// Make at most 6 blocks
t.updatedBlocks = Some((1 to (e % 5 + 1)).map { i =>
- (RDDBlockId(e % i, f % i), BlockStatus(StorageLevel.MEMORY_AND_DISK_SER_2, a % i, b % i))
+ (RDDBlockId(e % i, f % i), BlockStatus(StorageLevel.MEMORY_AND_DISK_SER_2, a % i, b % i, c%i))
}.toSeq)
t
}
@@ -471,19 +470,19 @@ class JsonProtocolSuite extends FunSuite {
"""
{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":100,"Stage Name":
"greetings","Number of Tasks":200,"RDD Info":{"RDD ID":100,"Name":"mayor","Storage
- Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},
- "Number of Partitions":200,"Number of Cached Partitions":300,"Memory Size":400,
- "Disk Size":500},"Emitted Task Size Warning":false},"Properties":{"France":"Paris",
- "Germany":"Berlin","Russia":"Moscow","Ukraine":"Kiev"}}
+ Level":{"Use Disk":true,"Use Memory":true,"Use Tachyon":false,"Deserialized":true,
+ "Replication":1},"Number of Partitions":200,"Number of Cached Partitions":300,
+ "Memory Size":400,"Disk Size":500,"Tachyon Size":0},"Emitted Task Size Warning":false},
+ "Properties":{"France":"Paris","Germany":"Berlin","Russia":"Moscow","Ukraine":"Kiev"}}
"""
private val stageCompletedJsonString =
"""
{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":101,"Stage Name":
"greetings","Number of Tasks":201,"RDD Info":{"RDD ID":101,"Name":"mayor","Storage
- Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1},
- "Number of Partitions":201,"Number of Cached Partitions":301,"Memory Size":401,
- "Disk Size":501},"Emitted Task Size Warning":false}}
+ Level":{"Use Disk":true,"Use Memory":true,"Use Tachyon":false,"Deserialized":true,
+ "Replication":1},"Number of Partitions":201,"Number of Cached Partitions":301,
+ "Memory Size":401,"Disk Size":501,"Tachyon Size":0},"Emitted Task Size Warning":false}}
"""
private val taskStartJsonString =
@@ -516,8 +515,8 @@ class JsonProtocolSuite extends FunSuite {
700,"Fetch Wait Time":900,"Remote Bytes Read":1000},"Shuffle Write Metrics":
{"Shuffle Bytes Written":1200,"Shuffle Write Time":1500},"Updated Blocks":
[{"Block ID":{"Type":"RDDBlockId","RDD ID":0,"Split Index":0},"Status":
- {"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":false,
- "Replication":2},"Memory Size":0,"Disk Size":0}}]}}
+ {"Storage Level":{"Use Disk":true,"Use Memory":true,"Use Tachyon":false,"Deserialized":false,
+ "Replication":2},"Memory Size":0,"Disk Size":0,"Tachyon Size":0}}]}}
"""
private val jobStartJsonString =
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index eb8f5915605de..eb7fb6318262b 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.util
import scala.util.Random
-import java.io.{ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream}
+import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream}
import java.nio.{ByteBuffer, ByteOrder}
import com.google.common.base.Charsets
@@ -39,7 +39,7 @@ class UtilsSuite extends FunSuite {
}
test("copyStream") {
- //input array initialization
+ // input array initialization
val bytes = Array.ofDim[Byte](9000)
Random.nextBytes(bytes)
@@ -154,5 +154,18 @@ class UtilsSuite extends FunSuite {
val iterator = Iterator.range(0, 5)
assert(Utils.getIteratorSize(iterator) === 5L)
}
+
+ test("findOldFiles") {
+ // create some temporary directories and files
+ val parent: File = Utils.createTempDir()
+ val child1: File = Utils.createTempDir(parent.getCanonicalPath) // The parent directory has two child directories
+ val child2: File = Utils.createTempDir(parent.getCanonicalPath)
+ // set the last modified time of child1 to 10 secs old
+ child1.setLastModified(System.currentTimeMillis() - (1000 * 10))
+
+ val result = Utils.findOldFiles(parent, 5) // find files older than 5 secs
+ assert(result.size.equals(1))
+ assert(result(0).getCanonicalPath.equals(child1.getCanonicalPath))
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala
index c32183c134f9c..b85a409a4b2e9 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala
@@ -69,4 +69,87 @@ class BitSetSuite extends FunSuite {
assert(bitset.nextSetBit(96) === 96)
assert(bitset.nextSetBit(97) === -1)
}
+
+ test( "xor len(bitsetX) < len(bitsetY)" ) {
+ val setBitsX = Seq( 0, 2, 3, 37, 41 )
+ val setBitsY = Seq( 0, 1, 3, 37, 38, 41, 85)
+ val bitsetX = new BitSet(60)
+ setBitsX.foreach( i => bitsetX.set(i))
+ val bitsetY = new BitSet(100)
+ setBitsY.foreach( i => bitsetY.set(i))
+
+ val bitsetXor = bitsetX ^ bitsetY
+
+ assert(bitsetXor.nextSetBit(0) === 1)
+ assert(bitsetXor.nextSetBit(1) === 1)
+ assert(bitsetXor.nextSetBit(2) === 2)
+ assert(bitsetXor.nextSetBit(3) === 38)
+ assert(bitsetXor.nextSetBit(38) === 38)
+ assert(bitsetXor.nextSetBit(39) === 85)
+ assert(bitsetXor.nextSetBit(42) === 85)
+ assert(bitsetXor.nextSetBit(85) === 85)
+ assert(bitsetXor.nextSetBit(86) === -1)
+
+ }
+
+ test( "xor len(bitsetX) > len(bitsetY)" ) {
+ val setBitsX = Seq( 0, 1, 3, 37, 38, 41, 85)
+ val setBitsY = Seq( 0, 2, 3, 37, 41 )
+ val bitsetX = new BitSet(100)
+ setBitsX.foreach( i => bitsetX.set(i))
+ val bitsetY = new BitSet(60)
+ setBitsY.foreach( i => bitsetY.set(i))
+
+ val bitsetXor = bitsetX ^ bitsetY
+
+ assert(bitsetXor.nextSetBit(0) === 1)
+ assert(bitsetXor.nextSetBit(1) === 1)
+ assert(bitsetXor.nextSetBit(2) === 2)
+ assert(bitsetXor.nextSetBit(3) === 38)
+ assert(bitsetXor.nextSetBit(38) === 38)
+ assert(bitsetXor.nextSetBit(39) === 85)
+ assert(bitsetXor.nextSetBit(42) === 85)
+ assert(bitsetXor.nextSetBit(85) === 85)
+ assert(bitsetXor.nextSetBit(86) === -1)
+
+ }
+
+ test( "andNot len(bitsetX) < len(bitsetY)" ) {
+ val setBitsX = Seq( 0, 2, 3, 37, 41, 48 )
+ val setBitsY = Seq( 0, 1, 3, 37, 38, 41, 85)
+ val bitsetX = new BitSet(60)
+ setBitsX.foreach( i => bitsetX.set(i))
+ val bitsetY = new BitSet(100)
+ setBitsY.foreach( i => bitsetY.set(i))
+
+ val bitsetDiff = bitsetX.andNot( bitsetY )
+
+ assert(bitsetDiff.nextSetBit(0) === 2)
+ assert(bitsetDiff.nextSetBit(1) === 2)
+ assert(bitsetDiff.nextSetBit(2) === 2)
+ assert(bitsetDiff.nextSetBit(3) === 48)
+ assert(bitsetDiff.nextSetBit(48) === 48)
+ assert(bitsetDiff.nextSetBit(49) === -1)
+ assert(bitsetDiff.nextSetBit(65) === -1)
+ }
+
+ test( "andNot len(bitsetX) > len(bitsetY)" ) {
+ val setBitsX = Seq( 0, 1, 3, 37, 38, 41, 85)
+ val setBitsY = Seq( 0, 2, 3, 37, 41, 48 )
+ val bitsetX = new BitSet(100)
+ setBitsX.foreach( i => bitsetX.set(i))
+ val bitsetY = new BitSet(60)
+ setBitsY.foreach( i => bitsetY.set(i))
+
+ val bitsetDiff = bitsetX.andNot( bitsetY )
+
+ assert(bitsetDiff.nextSetBit(0) === 1)
+ assert(bitsetDiff.nextSetBit(1) === 1)
+ assert(bitsetDiff.nextSetBit(2) === 38)
+ assert(bitsetDiff.nextSetBit(3) === 38)
+ assert(bitsetDiff.nextSetBit(38) === 38)
+ assert(bitsetDiff.nextSetBit(39) === 85)
+ assert(bitsetDiff.nextSetBit(85) === 85)
+ assert(bitsetDiff.nextSetBit(86) === -1)
+ }
}
diff --git a/dev/audit-release/README.md b/dev/audit-release/README.md
index 2437a98672177..38becda0eae92 100644
--- a/dev/audit-release/README.md
+++ b/dev/audit-release/README.md
@@ -4,7 +4,7 @@ run them locally by setting appropriate environment variables.
```
$ cd sbt_app_core
-$ SCALA_VERSION=2.10.3 \
+$ SCALA_VERSION=2.10.4 \
SPARK_VERSION=1.0.0-SNAPSHOT \
SPARK_RELEASE_REPOSITORY=file:///home/patrick/.ivy2/local \
sbt run
diff --git a/dev/audit-release/audit_release.py b/dev/audit-release/audit_release.py
index 52c367d9b030d..fa2f02dfecc75 100755
--- a/dev/audit-release/audit_release.py
+++ b/dev/audit-release/audit_release.py
@@ -35,7 +35,7 @@
RELEASE_KEY = "9E4FE3AF"
RELEASE_REPOSITORY = "https://repository.apache.org/content/repositories/orgapachespark-1006/"
RELEASE_VERSION = "1.0.0"
-SCALA_VERSION = "2.10.3"
+SCALA_VERSION = "2.10.4"
SCALA_BINARY_VERSION = "2.10"
##
diff --git a/dev/audit-release/maven_app_core/pom.xml b/dev/audit-release/maven_app_core/pom.xml
index 0b837c01751fe..76a381f8e17e0 100644
--- a/dev/audit-release/maven_app_core/pom.xml
+++ b/dev/audit-release/maven_app_core/pom.xml
@@ -49,7 +49,7 @@
maven-compiler-plugin
- 2.3.2
+ 3.1
diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh
index 995106f111443..bf1c5d7953bd2 100755
--- a/dev/create-release/create-release.sh
+++ b/dev/create-release/create-release.sh
@@ -49,14 +49,14 @@ mvn -DskipTests \
-Darguments="-DskipTests=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \
-Dusername=$GIT_USERNAME -Dpassword=$GIT_PASSWORD \
-Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \
- -Pyarn -Pspark-ganglia-lgpl \
+ -Pyarn -Phive -Pspark-ganglia-lgpl\
-Dtag=$GIT_TAG -DautoVersionSubmodules=true \
--batch-mode release:prepare
mvn -DskipTests \
-Darguments="-DskipTests=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \
-Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \
- -Pyarn -Pspark-ganglia-lgpl\
+ -Pyarn -Phive -Pspark-ganglia-lgpl\
release:perform
rm -rf spark
diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py
index e8f78fc5f231a..7a61943e94814 100755
--- a/dev/merge_spark_pr.py
+++ b/dev/merge_spark_pr.py
@@ -87,11 +87,20 @@ def merge_pr(pr_num, target_ref):
run_cmd("git fetch %s %s:%s" % (PUSH_REMOTE_NAME, target_ref, target_branch_name))
run_cmd("git checkout %s" % target_branch_name)
- run_cmd(['git', 'merge', pr_branch_name, '--squash'])
+ had_conflicts = False
+ try:
+ run_cmd(['git', 'merge', pr_branch_name, '--squash'])
+ except Exception as e:
+ msg = "Error merging: %s\nWould you like to manually fix-up this merge?" % e
+ continue_maybe(msg)
+ msg = "Okay, please fix any conflicts and 'git add' conflicting files... Finished?"
+ continue_maybe(msg)
+ had_conflicts = True
commit_authors = run_cmd(['git', 'log', 'HEAD..%s' % pr_branch_name,
'--pretty=format:%an <%ae>']).split("\n")
- distinct_authors = sorted(set(commit_authors), key=lambda x: commit_authors.count(x), reverse=True)
+ distinct_authors = sorted(set(commit_authors), key=lambda x: commit_authors.count(x),
+ reverse=True)
primary_author = distinct_authors[0]
commits = run_cmd(['git', 'log', 'HEAD..%s' % pr_branch_name,
'--pretty=format:%h [%an] %s']).split("\n\n")
@@ -105,6 +114,13 @@ def merge_pr(pr_num, target_ref):
merge_message_flags += ["-m", authors]
+ if had_conflicts:
+ committer_name = run_cmd("git config --get user.name").strip()
+ committer_email = run_cmd("git config --get user.email").strip()
+ message = "This patch had conflicts when merged, resolved by\nCommitter: %s <%s>" % (
+ committer_name, committer_email)
+ merge_message_flags += ["-m", message]
+
# The string "Closes #%s" string is required for GitHub to correctly close the PR
merge_message_flags += ["-m",
"Closes #%s from %s and squashes the following commits:" % (pr_num, pr_repo_desc)]
@@ -186,8 +202,10 @@ def maybe_cherry_pick(pr_num, merge_hash, default_branch):
maybe_cherry_pick(pr_num, merge_hash, latest_branch)
sys.exit(0)
-if bool(pr["mergeable"]) == False:
- fail("Pull request %s is not mergeable in its current form" % pr_num)
+if not bool(pr["mergeable"]):
+ msg = "Pull request %s is not mergeable in its current form.\n" % pr_num + \
+ "Continue? (experts only!)"
+ continue_maybe(msg)
print ("\n=== Pull Request #%s ===" % pr_num)
print("title\t%s\nsource\t%s\ntarget\t%s\nurl\t%s" % (
diff --git a/dev/run-tests b/dev/run-tests
index 6f115d2abd5b0..6ad674a2ba127 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -26,31 +26,31 @@ rm -rf ./work
# Fail fast
set -e
-
+set -o pipefail
if test -x "$JAVA_HOME/bin/java"; then
declare java_cmd="$JAVA_HOME/bin/java"
else
declare java_cmd=java
fi
-
JAVA_VERSION=$($java_cmd -version 2>&1 | sed 's/java version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q')
[ "$JAVA_VERSION" -ge 18 ] && echo "" || echo "[Warn] Java 8 tests will not run because JDK version is < 1.8."
echo "========================================================================="
echo "Running Apache RAT checks"
echo "========================================================================="
-
dev/check-license
echo "========================================================================="
echo "Running Scala style checks"
echo "========================================================================="
-sbt/sbt clean scalastyle
+dev/scalastyle
echo "========================================================================="
echo "Running Spark unit tests"
echo "========================================================================="
-sbt/sbt assembly test
+# echo "q" is needed because sbt on encountering a build file with failure (either resolution or compilation)
+# prompts the user for input either q, r, etc to quit or retry. This echo is there to make it not block.
+echo -e "q\n" | sbt/sbt assembly test | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including"
echo "========================================================================="
echo "Running PySpark tests"
@@ -64,5 +64,4 @@ echo "========================================================================="
echo "Detecting binary incompatibilites with MiMa"
echo "========================================================================="
./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore
-sbt/sbt mima-report-binary-issues
-
+echo -e "q\n" | sbt/sbt mima-report-binary-issues | grep -v -e "info.*Resolving"
diff --git a/dev/scalastyle b/dev/scalastyle
new file mode 100755
index 0000000000000..19955b9aaaad3
--- /dev/null
+++ b/dev/scalastyle
@@ -0,0 +1,27 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+echo -e "q\n" | sbt/sbt clean scalastyle > scalastyle.txt
+ERRORS=$(cat scalastyle.txt | grep -e "\")
+if test ! -z "$ERRORS"; then
+ echo -e "Scalastyle checks failed at following occurrences:\n$ERRORS"
+ exit 1
+else
+ echo -e "Scalastyle checks passed.\n"
+fi
diff --git a/docker/spark-test/base/Dockerfile b/docker/spark-test/base/Dockerfile
index e543db6143e4d..5956d59130fbf 100644
--- a/docker/spark-test/base/Dockerfile
+++ b/docker/spark-test/base/Dockerfile
@@ -25,7 +25,7 @@ RUN apt-get update
# install a few other useful packages plus Open Jdk 7
RUN apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server
-ENV SCALA_VERSION 2.10.3
+ENV SCALA_VERSION 2.10.4
ENV CDH_VERSION cdh4
ENV SCALA_HOME /opt/scala-$SCALA_VERSION
ENV SPARK_HOME /opt/spark
diff --git a/docs/_config.yml b/docs/_config.yml
index aa5a5adbc1743..d585b8c5ea763 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -6,7 +6,7 @@ markdown: kramdown
SPARK_VERSION: 1.0.0-SNAPSHOT
SPARK_VERSION_SHORT: 1.0.0
SCALA_BINARY_VERSION: "2.10"
-SCALA_VERSION: "2.10.3"
+SCALA_VERSION: "2.10.4"
MESOS_VERSION: 0.13.0
SPARK_ISSUE_TRACKER_URL: https://spark-project.atlassian.net
SPARK_GITHUB_URL: https://github.com/apache/spark
diff --git a/docs/building-with-maven.md b/docs/building-with-maven.md
index 730a6e7932564..9cebaf12283fc 100644
--- a/docs/building-with-maven.md
+++ b/docs/building-with-maven.md
@@ -6,7 +6,7 @@ title: Building Spark with Maven
* This will become a table of contents (this text will be scraped).
{:toc}
-Building Spark using Maven Requires Maven 3 (the build process is tested with Maven 3.0.4) and Java 1.6 or newer.
+Building Spark using Maven requires Maven 3.0.4 or newer and Java 1.6 or newer.
## Setting up Maven's Memory Usage ##
diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md
index a555a7b5023e3..7f75ea44e4cea 100644
--- a/docs/cluster-overview.md
+++ b/docs/cluster-overview.md
@@ -50,6 +50,50 @@ The system currently supports three cluster managers:
In addition, Spark's [EC2 launch scripts](ec2-scripts.html) make it easy to launch a standalone
cluster on Amazon EC2.
+# Launching Applications
+
+The recommended way to launch a compiled Spark application is through the spark-submit script (located in the
+bin directory), which takes care of setting up the classpath with Spark and its dependencies, as well as
+provides a layer over the different cluster managers and deploy modes that Spark supports. It's usage is
+
+ spark-submit `` ``
+
+Where options are any of:
+
+- **\--class** - The main class to run.
+- **\--master** - The URL of the cluster manager master, e.g. spark://host:port, mesos://host:port, yarn,
+ or local.
+- **\--deploy-mode** - "client" to run the driver in the client process or "cluster" to run the driver in
+ a process on the cluster. For Mesos, only "client" is supported.
+- **\--executor-memory** - Memory per executor (e.g. 1000M, 2G).
+- **\--executor-cores** - Number of cores per executor. (Default: 2)
+- **\--driver-memory** - Memory for driver (e.g. 1000M, 2G)
+- **\--name** - Name of the application.
+- **\--arg** - Argument to be passed to the application's main class. This option can be specified
+ multiple times to pass multiple arguments.
+- **\--jars** - A comma-separated list of local jars to include on the driver classpath and that
+ SparkContext.addJar will work with. Doesn't work on standalone with 'cluster' deploy mode.
+
+The following currently only work for Spark standalone with cluster deploy mode:
+
+- **\--driver-cores** - Cores for driver (Default: 1).
+- **\--supervise** - If given, restarts the driver on failure.
+
+The following only works for Spark standalone and Mesos only:
+
+- **\--total-executor-cores** - Total cores for all executors.
+
+The following currently only work for YARN:
+
+- **\--queue** - The YARN queue to place the application in.
+- **\--files** - Comma separated list of files to be placed in the working dir of each executor.
+- **\--archives** - Comma separated list of archives to be extracted into the working dir of each
+ executor.
+- **\--num-executors** - Number of executors (Default: 2).
+
+The master and deploy mode can also be set with the MASTER and DEPLOY_MODE environment variables.
+Values for these options passed via command line will override the environment variables.
+
# Shipping Code to the Cluster
The recommended way to ship your code to the cluster is to pass it through SparkContext's constructor,
@@ -102,6 +146,12 @@ The following table summarizes terms you'll see used to refer to cluster concept
Cluster manager |
An external service for acquiring resources on the cluster (e.g. standalone manager, Mesos, YARN) |
+
+ Deploy mode |
+ Distinguishes where the driver process runs. In "cluster" mode, the framework launches
+ the driver inside of the cluster. In "client" mode, the submitter launches the driver
+ outside of the cluster. |
+
Worker node |
Any node that can run application code in the cluster |
diff --git a/docs/configuration.md b/docs/configuration.md
index 1ff0150567255..57bda20edcdf1 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -122,6 +122,21 @@ Apart from these, the following properties are also available, and may be useful
spark.storage.memoryFraction
.
+
+ spark.tachyonStore.baseDir |
+ System.getProperty("java.io.tmpdir") |
+
+ Directories of the Tachyon File System that store RDDs. The Tachyon file system's URL is set by spark.tachyonStore.url .
+ It can also be a comma-separated list of multiple directories on Tachyon file system.
+ |
+
+
+ spark.tachyonStore.url |
+ tachyon://localhost:19998 |
+
+ The URL of the underlying Tachyon file system in the TachyonStore.
+ |
+
spark.mesos.coarse |
false |
@@ -161,13 +176,13 @@ Apart from these, the following properties are also available, and may be useful
spark.ui.acls.enable |
false |
- Whether spark web ui acls should are enabled. If enabled, this checks to see if the user has
+ Whether spark web ui acls should are enabled. If enabled, this checks to see if the user has
access permissions to view the web ui. See spark.ui.view.acls for more details.
Also note this requires the user to be known, if the user comes across as null no checks
are done. Filters can be used to authenticate and set the user.
|
-
+
spark.ui.view.acls |
Empty |
@@ -276,10 +291,10 @@ Apart from these, the following properties are also available, and may be useful
| spark.serializer.objectStreamReset |
10000 |
- When serializing using org.apache.spark.serializer.JavaSerializer, the serializer caches
- objects to prevent writing redundant data, however that stops garbage collection of those
- objects. By calling 'reset' you flush that info from the serializer, and allow old
- objects to be collected. To turn off this periodic reset set it to a value of <= 0.
+ When serializing using org.apache.spark.serializer.JavaSerializer, the serializer caches
+ objects to prevent writing redundant data, however that stops garbage collection of those
+ objects. By calling 'reset' you flush that info from the serializer, and allow old
+ objects to be collected. To turn off this periodic reset set it to a value of <= 0.
By default it will reset the serializer every 10,000 objects.
|
@@ -333,6 +348,32 @@ Apart from these, the following properties are also available, and may be useful
receives no heartbeats.
+
+ spark.worker.cleanup.enabled |
+ true |
+
+ Enable periodic cleanup of worker / application directories. Note that this only affects standalone
+ mode, as YARN works differently.
+ |
+
+
+ spark.worker.cleanup.interval |
+ 1800 (30 minutes) |
+
+ Controls the interval, in seconds, at which the worker cleans up old application work dirs
+ on the local machine.
+ |
+
+
+ spark.worker.cleanup.appDataTtl |
+ 7 * 24 * 3600 (7 days) |
+
+ The number of seconds to retain application work directories on each worker. This is a Time To Live
+ and should depend on the amount of available disk space you have. Application logs and jars are
+ downloaded to each application work dir. Over time, the work dirs can quickly fill up disk space,
+ especially if you run jobs very frequently.
+ |
+
spark.akka.frameSize |
10 |
@@ -375,7 +416,7 @@ Apart from these, the following properties are also available, and may be useful
spark.akka.heartbeat.interval |
1000 |
- This is set to a larger value to disable failure detector that comes inbuilt akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger interval value in seconds reduces network overhead and a smaller value ( ~ 1 s) might be more informative for akka's failure detector. Tune this in combination of `spark.akka.heartbeat.pauses` and `spark.akka.failure-detector.threshold` if you need to. Only positive use case for using failure detector can be, a sensistive failure detector can help evict rogue executors really quick. However this is usually not the case as gc pauses and network lags are expected in a real spark cluster. Apart from that enabling this leads to a lot of exchanges of heart beats between nodes leading to flooding the network with those.
+ This is set to a larger value to disable failure detector that comes inbuilt akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger interval value in seconds reduces network overhead and a smaller value ( ~ 1 s) might be more informative for akka's failure detector. Tune this in combination of `spark.akka.heartbeat.pauses` and `spark.akka.failure-detector.threshold` if you need to. Only positive use case for using failure detector can be, a sensistive failure detector can help evict rogue executors really quick. However this is usually not the case as gc pauses and network lags are expected in a real spark cluster. Apart from that enabling this leads to a lot of exchanges of heart beats between nodes leading to flooding the network with those.
|
@@ -430,7 +471,7 @@ Apart from these, the following properties are also available, and may be useful
spark.broadcast.blockSize |
4096 |
- Size of each piece of a block in kilobytes for TorrentBroadcastFactory .
+ Size of each piece of a block in kilobytes for TorrentBroadcastFactory .
Too large a value decreases parallelism during broadcast (makes it slower); however, if it is too small, BlockManager might take a performance hit.
|
@@ -555,7 +596,7 @@ Apart from these, the following properties are also available, and may be useful
the driver.
-
+
spark.authenticate |
false |
@@ -563,7 +604,7 @@ Apart from these, the following properties are also available, and may be useful
running on Yarn.
|
-
+
spark.authenticate.secret |
None |
@@ -571,12 +612,12 @@ Apart from these, the following properties are also available, and may be useful
not running on Yarn and authentication is enabled.
|
-
+
spark.core.connection.auth.wait.timeout |
30 |
Number of seconds for the connection to wait for authentication to occur before timing
- out and giving up.
+ out and giving up.
|
diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md
index 203d235bf9663..a5e0cc50809cf 100644
--- a/docs/mllib-guide.md
+++ b/docs/mllib-guide.md
@@ -38,6 +38,5 @@ depends on native Fortran routines. You may need to install the
if it is not already present on your nodes. MLlib will throw a linking error if it cannot
detect these libraries automatically.
-To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.7 or newer
-and Python 2.7.
+To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.7 or newer.
diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md
index cbe7d820b455e..888631e7025b0 100644
--- a/docs/python-programming-guide.md
+++ b/docs/python-programming-guide.md
@@ -82,15 +82,16 @@ The Python shell can be used explore data interactively and is a simple way to l
>>> help(pyspark) # Show all pyspark functions
{% endhighlight %}
-By default, the `bin/pyspark` shell creates SparkContext that runs applications locally on a single core.
-To connect to a non-local cluster, or use multiple cores, set the `MASTER` environment variable.
+By default, the `bin/pyspark` shell creates SparkContext that runs applications locally on all of
+your machine's logical cores.
+To connect to a non-local cluster, or to specify a number of cores, set the `MASTER` environment variable.
For example, to use the `bin/pyspark` shell with a [standalone Spark cluster](spark-standalone.html):
{% highlight bash %}
$ MASTER=spark://IP:PORT ./bin/pyspark
{% endhighlight %}
-Or, to use four cores on the local machine:
+Or, to use exactly four cores on the local machine:
{% highlight bash %}
$ MASTER=local[4] ./bin/pyspark
@@ -152,7 +153,7 @@ Many of the methods also contain [doctests](http://docs.python.org/2/library/doc
# Libraries
[MLlib](mllib-guide.html) is also available in PySpark. To use it, you'll need
-[NumPy](http://www.numpy.org) version 1.7 or newer, and Python 2.7. The [MLlib guide](mllib-guide.html) contains
+[NumPy](http://www.numpy.org) version 1.7 or newer. The [MLlib guide](mllib-guide.html) contains
some example applications.
# Where to Go from Here
diff --git a/docs/quick-start.md b/docs/quick-start.md
index 13df6beea16e8..60e8b1ba0eb46 100644
--- a/docs/quick-start.md
+++ b/docs/quick-start.md
@@ -124,7 +124,7 @@ object SimpleApp {
}
{% endhighlight %}
-This program just counts the number of lines containing 'a' and the number containing 'b' in the Spark README. Note that you'll need to replace $YOUR_SPARK_HOME with the location where Spark is installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkContext, we initialize a SparkContext as part of the proogram. We pass the SparkContext constructor four arguments, the type of scheduler we want to use (in this case, a local scheduler), a name for the application, the directory where Spark is installed, and a name for the jar file containing the application's code. The final two arguments are needed in a distributed setting, where Spark is running across several nodes, so we include them for completeness. Spark will automatically ship the jar files you list to slave nodes.
+This program just counts the number of lines containing 'a' and the number containing 'b' in the Spark README. Note that you'll need to replace $YOUR_SPARK_HOME with the location where Spark is installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkContext, we initialize a SparkContext as part of the program. We pass the SparkContext constructor four arguments, the type of scheduler we want to use (in this case, a local scheduler), a name for the application, the directory where Spark is installed, and a name for the jar file containing the application's code. The final two arguments are needed in a distributed setting, where Spark is running across several nodes, so we include them for completeness. Spark will automatically ship the jar files you list to slave nodes.
This file depends on the Spark API, so we'll also include an sbt configuration file, `simple.sbt` which explains that Spark is a dependency. This file also adds a repository that Spark depends on:
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index 2e9dec4856ee9..982514391ac00 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -48,10 +48,12 @@ System Properties:
Ensure that HADOOP_CONF_DIR or YARN_CONF_DIR points to the directory which contains the (client side) configuration files for the Hadoop cluster.
These configs are used to connect to the cluster, write to the dfs, and connect to the YARN ResourceManager.
-There are two scheduler modes that can be used to launch Spark applications on YARN. In yarn-cluster mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In yarn-client mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN.
+There are two deploy modes that can be used to launch Spark applications on YARN. In yarn-cluster mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In yarn-client mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN.
Unlike in Spark standalone and Mesos mode, in which the master's address is specified in the "master" parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the master parameter is simply "yarn-client" or "yarn-cluster".
+The spark-submit script described in the [cluster mode overview](cluster-overview.html) provides the most straightforward way to submit a compiled Spark application to YARN in either deploy mode. For info on the lower-level invocations it uses, read ahead. For running spark-shell against YARN, skip down to the yarn-client section.
+
## Launching a Spark application with yarn-cluster mode.
The command to launch the Spark application on the cluster is as follows:
@@ -59,7 +61,7 @@ The command to launch the Spark application on the cluster is as follows:
SPARK_JAR= ./bin/spark-class org.apache.spark.deploy.yarn.Client \
--jar \
--class \
- --args \
+ --arg \
--num-executors \
--driver-memory \
--executor-memory \
@@ -70,7 +72,7 @@ The command to launch the Spark application on the cluster is as follows:
--files \
--archives
-For example:
+To pass multiple arguments the "arg" option can be specified multiple times. For example:
# Build the Spark assembly JAR and the Spark examples JAR
$ SPARK_HADOOP_VERSION=2.0.5-alpha SPARK_YARN=true sbt/sbt assembly
@@ -83,7 +85,8 @@ For example:
./bin/spark-class org.apache.spark.deploy.yarn.Client \
--jar examples/target/scala-{{site.SCALA_BINARY_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \
--class org.apache.spark.examples.SparkPi \
- --args yarn-cluster \
+ --arg yarn-cluster \
+ --arg 5 \
--num-executors 3 \
--driver-memory 4g \
--executor-memory 2g \
@@ -121,7 +124,7 @@ or
MASTER=yarn-client ./bin/spark-shell
-## Viewing logs
+# Viewing logs
In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the yarn.log-aggregation-enable config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command.
diff --git a/docs/scala-programming-guide.md b/docs/scala-programming-guide.md
index 99412733d4268..a07cd2e0a32a2 100644
--- a/docs/scala-programming-guide.md
+++ b/docs/scala-programming-guide.md
@@ -23,7 +23,7 @@ To write a Spark application, you need to add a dependency on Spark. If you use
groupId = org.apache.spark
artifactId = spark-core_{{site.SCALA_BINARY_VERSION}}
- version = {{site.SPARK_VERSION}}
+ version = {{site.SPARK_VERSION}}
In addition, if you wish to access an HDFS cluster, you need to add a dependency on `hadoop-client` for your version of HDFS:
@@ -54,7 +54,7 @@ object for more advanced configuration.
The `master` parameter is a string specifying a [Spark or Mesos cluster URL](#master-urls) to connect to, or a special "local" string to run in local mode, as described below. `appName` is a name for your application, which will be shown in the cluster web UI. Finally, the last two parameters are needed to deploy your code to a cluster if running in distributed mode, as described later.
-In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `MASTER` environment variable, and you can add JARs to the classpath with the `ADD_JARS` variable. For example, to run `bin/spark-shell` on four cores, use
+In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `MASTER` environment variable, and you can add JARs to the classpath with the `ADD_JARS` variable. For example, to run `bin/spark-shell` on exactly four cores, use
{% highlight bash %}
$ MASTER=local[4] ./bin/spark-shell
@@ -73,18 +73,19 @@ The master URL passed to Spark can be in one of the following formats:
Master URL | Meaning |
local | Run Spark locally with one worker thread (i.e. no parallelism at all). |
- local[K] | Run Spark locally with K worker threads (ideally, set this to the number of cores on your machine).
+ |
local[K] | Run Spark locally with K worker threads (ideally, set this to the number of cores on your machine).
+ |
local[*] | Run Spark locally with as many worker threads as logical cores on your machine. |
- spark://HOST:PORT | Connect to the given Spark standalone
- cluster master. The port must be whichever one your master is configured to use, which is 7077 by default.
+ |
spark://HOST:PORT | Connect to the given Spark standalone
+ cluster master. The port must be whichever one your master is configured to use, which is 7077 by default.
|
- mesos://HOST:PORT | Connect to the given Mesos cluster.
- The host parameter is the hostname of the Mesos master. The port must be whichever one the master is configured to use,
- which is 5050 by default.
+ |
mesos://HOST:PORT | Connect to the given Mesos cluster.
+ The host parameter is the hostname of the Mesos master. The port must be whichever one the master is configured to use,
+ which is 5050 by default.
|
-If no master URL is specified, the spark shell defaults to "local".
+If no master URL is specified, the spark shell defaults to "local[*]".
For running on YARN, Spark launches an instance of the standalone deploy cluster within YARN; see [running on YARN](running-on-yarn.html) for details.
@@ -265,11 +266,25 @@ A complete list of actions is available in the [RDD API doc](api/core/index.html
## RDD Persistence
-One of the most important capabilities in Spark is *persisting* (or *caching*) a dataset in memory across operations. When you persist an RDD, each node stores any slices of it that it computes in memory and reuses them in other actions on that dataset (or datasets derived from it). This allows future actions to be much faster (often by more than 10x). Caching is a key tool for building iterative algorithms with Spark and for interactive use from the interpreter.
-
-You can mark an RDD to be persisted using the `persist()` or `cache()` methods on it. The first time it is computed in an action, it will be kept in memory on the nodes. The cache is fault-tolerant -- if any partition of an RDD is lost, it will automatically be recomputed using the transformations that originally created it.
-
-In addition, each RDD can be stored using a different *storage level*, allowing you, for example, to persist the dataset on disk, or persist it in memory but as serialized Java objects (to save space), or even replicate it across nodes. These levels are chosen by passing a [`org.apache.spark.storage.StorageLevel`](api/core/index.html#org.apache.spark.storage.StorageLevel) object to `persist()`. The `cache()` method is a shorthand for using the default storage level, which is `StorageLevel.MEMORY_ONLY` (store deserialized objects in memory). The complete set of available storage levels is:
+One of the most important capabilities in Spark is *persisting* (or *caching*) a dataset in memory
+across operations. When you persist an RDD, each node stores any slices of it that it computes in
+memory and reuses them in other actions on that dataset (or datasets derived from it). This allows
+future actions to be much faster (often by more than 10x). Caching is a key tool for building
+iterative algorithms with Spark and for interactive use from the interpreter.
+
+You can mark an RDD to be persisted using the `persist()` or `cache()` methods on it. The first time
+it is computed in an action, it will be kept in memory on the nodes. The cache is fault-tolerant --
+if any partition of an RDD is lost, it will automatically be recomputed using the transformations
+that originally created it.
+
+In addition, each RDD can be stored using a different *storage level*, allowing you, for example, to
+persist the dataset on disk, or persist it in memory but as serialized Java objects (to save space),
+or replicate it across nodes, or store the data in off-heap memory in [Tachyon](http://tachyon-project.org/).
+These levels are chosen by passing a
+[`org.apache.spark.storage.StorageLevel`](api/core/index.html#org.apache.spark.storage.StorageLevel)
+object to `persist()`. The `cache()` method is a shorthand for using the default storage level,
+which is `StorageLevel.MEMORY_ONLY` (store deserialized objects in memory). The complete set of
+available storage levels is:
Storage Level | Meaning |
@@ -292,8 +307,16 @@ In addition, each RDD can be stored using a different *storage level*, allowing
MEMORY_AND_DISK_SER |
- Similar to MEMORY_ONLY_SER, but spill partitions that don't fit in memory to disk instead of recomputing them
- on the fly each time they're needed. |
+ Similar to MEMORY_ONLY_SER, but spill partitions that don't fit in memory to disk instead of
+ recomputing them on the fly each time they're needed. |
+
+
+ OFF_HEAP |
+ Store RDD in a serialized format in Tachyon.
+ This is generally more space-efficient than deserialized objects, especially when using a
+ fast serializer, but more CPU-intensive to read.
+ This also significantly reduces the overheads of GC.
+ |
DISK_ONLY |
@@ -307,30 +330,59 @@ In addition, each RDD can be stored using a different *storage level*, allowing
### Which Storage Level to Choose?
-Spark's storage levels are meant to provide different tradeoffs between memory usage and CPU efficiency.
-We recommend going through the following process to select one:
-
-* If your RDDs fit comfortably with the default storage level (`MEMORY_ONLY`), leave them that way. This is the most
- CPU-efficient option, allowing operations on the RDDs to run as fast as possible.
-* If not, try using `MEMORY_ONLY_SER` and [selecting a fast serialization library](tuning.html) to make the objects
- much more space-efficient, but still reasonably fast to access.
-* Don't spill to disk unless the functions that computed your datasets are expensive, or they filter a large
- amount of the data. Otherwise, recomputing a partition is about as fast as reading it from disk.
-* Use the replicated storage levels if you want fast fault recovery (e.g. if using Spark to serve requests from a web
- application). *All* the storage levels provide full fault tolerance by recomputing lost data, but the replicated ones
- let you continue running tasks on the RDD without waiting to recompute a lost partition.
-
-If you want to define your own storage level (say, with replication factor of 3 instead of 2), then use the function factor method `apply()` of the [`StorageLevel`](api/core/index.html#org.apache.spark.storage.StorageLevel$) singleton object.
+Spark's storage levels are meant to provide different trade-offs between memory usage and CPU
+efficiency. It allows uses to choose memory, disk, or Tachyon for storing data. We recommend going
+through the following process to select one:
+
+* If your RDDs fit comfortably with the default storage level (`MEMORY_ONLY`), leave them that way.
+ This is the most CPU-efficient option, allowing operations on the RDDs to run as fast as possible.
+
+* If not, try using `MEMORY_ONLY_SER` and [selecting a fast serialization library](tuning.html) to
+make the objects much more space-efficient, but still reasonably fast to access. You can also use
+`OFF_HEAP` mode to store the data off the heap in [Tachyon](http://tachyon-project.org/). This will
+significantly reduce JVM GC overhead.
+
+* Don't spill to disk unless the functions that computed your datasets are expensive, or they filter
+a large amount of the data. Otherwise, recomputing a partition is about as fast as reading it from
+disk.
+
+* Use the replicated storage levels if you want fast fault recovery (e.g. if using Spark to serve
+requests from a web application). *All* the storage levels provide full fault tolerance by
+recomputing lost data, but the replicated ones let you continue running tasks on the RDD without
+waiting to recompute a lost partition.
+
+If you want to define your own storage level (say, with replication factor of 3 instead of 2), then
+use the function factor method `apply()` of the
+[`StorageLevel`](api/core/index.html#org.apache.spark.storage.StorageLevel$) singleton object.
+
+Spark has a block manager inside the Executors that let you chose memory, disk, or off-heap. The
+latter is for storing RDDs off-heap outside the Executor JVM on top of the memory management system
+[Tachyon](http://tachyon-project.org/). This mode has the following advantages:
+
+* Cached data will not be lost if individual executors crash.
+* Executors can have a smaller memory footprint, allowing you to run more executors on the same
+machine as the bulk of the memory will be inside Tachyon.
+* Reduced GC overhead since data is stored in Tachyon.
# Shared Variables
-Normally, when a function passed to a Spark operation (such as `map` or `reduce`) is executed on a remote cluster node, it works on separate copies of all the variables used in the function. These variables are copied to each machine, and no updates to the variables on the remote machine are propagated back to the driver program. Supporting general, read-write shared variables across tasks would be inefficient. However, Spark does provide two limited types of *shared variables* for two common usage patterns: broadcast variables and accumulators.
+Normally, when a function passed to a Spark operation (such as `map` or `reduce`) is executed on a
+remote cluster node, it works on separate copies of all the variables used in the function. These
+variables are copied to each machine, and no updates to the variables on the remote machine are
+propagated back to the driver program. Supporting general, read-write shared variables across tasks
+would be inefficient. However, Spark does provide two limited types of *shared variables* for two
+common usage patterns: broadcast variables and accumulators.
## Broadcast Variables
-Broadcast variables allow the programmer to keep a read-only variable cached on each machine rather than shipping a copy of it with tasks. They can be used, for example, to give every node a copy of a large input dataset in an efficient manner. Spark also attempts to distribute broadcast variables using efficient broadcast algorithms to reduce communication cost.
+Broadcast variables allow the programmer to keep a read-only variable cached on each machine rather
+than shipping a copy of it with tasks. They can be used, for example, to give every node a copy of a
+large input dataset in an efficient manner. Spark also attempts to distribute broadcast variables
+using efficient broadcast algorithms to reduce communication cost.
-Broadcast variables are created from a variable `v` by calling `SparkContext.broadcast(v)`. The broadcast variable is a wrapper around `v`, and its value can be accessed by calling the `value` method. The interpreter session below shows this:
+Broadcast variables are created from a variable `v` by calling `SparkContext.broadcast(v)`. The
+broadcast variable is a wrapper around `v`, and its value can be accessed by calling the `value`
+method. The interpreter session below shows this:
{% highlight scala %}
scala> val broadcastVar = sc.broadcast(Array(1, 2, 3))
@@ -340,13 +392,21 @@ scala> broadcastVar.value
res0: Array[Int] = Array(1, 2, 3)
{% endhighlight %}
-After the broadcast variable is created, it should be used instead of the value `v` in any functions run on the cluster so that `v` is not shipped to the nodes more than once. In addition, the object `v` should not be modified after it is broadcast in order to ensure that all nodes get the same value of the broadcast variable (e.g. if the variable is shipped to a new node later).
+After the broadcast variable is created, it should be used instead of the value `v` in any functions
+run on the cluster so that `v` is not shipped to the nodes more than once. In addition, the object
+`v` should not be modified after it is broadcast in order to ensure that all nodes get the same
+value of the broadcast variable (e.g. if the variable is shipped to a new node later).
## Accumulators
-Accumulators are variables that are only "added" to through an associative operation and can therefore be efficiently supported in parallel. They can be used to implement counters (as in MapReduce) or sums. Spark natively supports accumulators of numeric value types and standard mutable collections, and programmers can add support for new types.
+Accumulators are variables that are only "added" to through an associative operation and can
+therefore be efficiently supported in parallel. They can be used to implement counters (as in
+MapReduce) or sums. Spark natively supports accumulators of numeric value types and standard mutable
+collections, and programmers can add support for new types.
-An accumulator is created from an initial value `v` by calling `SparkContext.accumulator(v)`. Tasks running on the cluster can then add to it using the `+=` operator. However, they cannot read its value. Only the driver program can read the accumulator's value, using its `value` method.
+An accumulator is created from an initial value `v` by calling `SparkContext.accumulator(v)`. Tasks
+running on the cluster can then add to it using the `+=` operator. However, they cannot read its
+value. Only the driver program can read the accumulator's value, using its `value` method.
The interpreter session below shows an accumulator being used to add up the elements of an array:
diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md
index 51fb3a4f7f8c5..7e4eea323aa63 100644
--- a/docs/spark-standalone.md
+++ b/docs/spark-standalone.md
@@ -146,10 +146,13 @@ automatically set MASTER from the `SPARK_MASTER_IP` and `SPARK_MASTER_PORT` vari
You can also pass an option `-c ` to control the number of cores that spark-shell uses on the cluster.
-# Launching Applications Inside the Cluster
+# Launching Compiled Spark Applications
-You may also run your application entirely inside of the cluster by submitting your application driver using the submission client. The syntax for submitting applications is as follows:
+Spark supports two deploy modes. Spark applications may run with the driver inside the client process or entirely inside the cluster.
+The spark-submit script described in the [cluster mode overview](cluster-overview.html) provides the most straightforward way to submit a compiled Spark application to the cluster in either deploy mode. For info on the lower-level invocations used to launch an app inside the cluster, read ahead.
+
+## Launching Applications Inside the Cluster
./bin/spark-class org.apache.spark.deploy.Client launch
[client-options] \
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index b6f21a5dc62c3..a59393e1424de 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -8,6 +8,10 @@ title: Spark SQL Programming Guide
{:toc}
# Overview
+
+
+
+
Spark SQL allows relational queries expressed in SQL, HiveQL, or Scala to be executed using
Spark. At the core of this component is a new type of RDD,
[SchemaRDD](api/sql/core/index.html#org.apache.spark.sql.SchemaRDD). SchemaRDDs are composed
@@ -18,11 +22,27 @@ file, or by running HiveQL against data stored in [Apache Hive](http://hive.apac
**All of the examples on this page use sample data included in the Spark distribution and can be run in the spark-shell.**
+
+
+
+Spark SQL allows relational queries expressed in SQL, HiveQL, or Scala to be executed using
+Spark. At the core of this component is a new type of RDD,
+[JavaSchemaRDD](api/sql/core/index.html#org.apache.spark.sql.api.java.JavaSchemaRDD). JavaSchemaRDDs are composed
+[Row](api/sql/catalyst/index.html#org.apache.spark.sql.api.java.Row) objects along with
+a schema that describes the data types of each column in the row. A JavaSchemaRDD is similar to a table
+in a traditional relational database. A JavaSchemaRDD can be created from an existing RDD, parquet
+file, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/).
+
+
+
***************************************************************************************************
# Getting Started
-The entry point into all relational functionallity in Spark is the
+
+
+
+The entry point into all relational functionality in Spark is the
[SQLContext](api/sql/core/index.html#org.apache.spark.sql.SQLContext) class, or one of its
decendents. To create a basic SQLContext, all you need is a SparkContext.
@@ -34,8 +54,30 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc)
import sqlContext._
{% endhighlight %}
+
+
+
+
+The entry point into all relational functionality in Spark is the
+[JavaSQLContext](api/sql/core/index.html#org.apache.spark.sql.api.java.JavaSQLContext) class, or one
+of its decendents. To create a basic JavaSQLContext, all you need is a JavaSparkContext.
+
+{% highlight java %}
+JavaSparkContext ctx = ...; // An existing JavaSparkContext.
+JavaSQLContext sqlCtx = new org.apache.spark.sql.api.java.JavaSQLContext(ctx);
+{% endhighlight %}
+
+
+
+
+
## Running SQL on RDDs
-One type of table that is supported by Spark SQL is an RDD of Scala case classetees. The case class
+
+
+
+
+
+One type of table that is supported by Spark SQL is an RDD of Scala case classes. The case class
defines the schema of the table. The names of the arguments to the case class are read using
reflection and become the names of the columns. Case classes can also be nested or contain complex
types such as Sequences or Arrays. This RDD can be implicitly converted to a SchemaRDD and then be
@@ -60,7 +102,83 @@ val teenagers = sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
{% endhighlight %}
-**Note that Spark SQL currently uses a very basic SQL parser, and the keywords are case sensitive.**
+
+
+
+
+One type of table that is supported by Spark SQL is an RDD of [JavaBeans](http://stackoverflow.com/questions/3295496/what-is-a-javabean-exactly). The BeanInfo
+defines the schema of the table. Currently, Spark SQL does not support JavaBeans that contain
+nested or contain complex types such as Lists or Arrays. You can create a JavaBean by creating a
+class that implements Serializable and has getters and setters for all of its fields.
+
+{% highlight java %}
+
+public static class Person implements Serializable {
+ private String name;
+ private int age;
+
+ String getName() {
+ return name;
+ }
+
+ void setName(String name) {
+ this.name = name;
+ }
+
+ int getAge() {
+ return age;
+ }
+
+ void setAge(int age) {
+ this.age = age;
+ }
+}
+
+{% endhighlight %}
+
+
+A schema can be applied to an existing RDD by calling `applySchema` and providing the Class object
+for the JavaBean.
+
+{% highlight java %}
+JavaSQLContext ctx = new org.apache.spark.sql.api.java.JavaSQLContext(sc)
+
+// Load a text file and convert each line to a JavaBean.
+JavaRDD
people = ctx.textFile("examples/src/main/resources/people.txt").map(
+ new Function() {
+ public Person call(String line) throws Exception {
+ String[] parts = line.split(",");
+
+ Person person = new Person();
+ person.setName(parts[0]);
+ person.setAge(Integer.parseInt(parts[1].trim()));
+
+ return person;
+ }
+ });
+
+// Apply a schema to an RDD of JavaBeans and register it as a table.
+JavaSchemaRDD schemaPeople = sqlCtx.applySchema(people, Person.class);
+schemaPeople.registerAsTable("people");
+
+// SQL can be run over RDDs that have been registered as tables.
+JavaSchemaRDD teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
+
+// The results of SQL queries are SchemaRDDs and support all the normal RDD operations.
+// The columns of a row in the result can be accessed by ordinal.
+List teenagerNames = teenagers.map(new Function() {
+ public String call(Row row) {
+ return "Name: " + row.getString(0);
+ }
+}).collect();
+
+{% endhighlight %}
+
+
+
+
+
+**Note that Spark SQL currently uses a very basic SQL parser.**
Users that want a more complete dialect of SQL should look at the HiveQL support provided by
`HiveContext`.
@@ -70,17 +188,21 @@ Parquet is a columnar format that is supported by many other data processing sys
provides support for both reading and writing parquet files that automatically preserves the schema
of the original data. Using the data from the above example:
+
+
+
+
{% highlight scala %}
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
import sqlContext._
-val people: RDD[Person] // An RDD of case class objects, from the previous example.
+val people: RDD[Person] = ... // An RDD of case class objects, from the previous example.
// The RDD is implicitly converted to a SchemaRDD, allowing it to be stored using parquet.
people.saveAsParquetFile("people.parquet")
// Read in the parquet file created above. Parquet files are self-describing so the schema is preserved.
-// The result of loading a parquet file is also a SchemaRDD.
+// The result of loading a parquet file is also a JavaSchemaRDD.
val parquetFile = sqlContext.parquetFile("people.parquet")
//Parquet files can also be registered as tables and then used in SQL statements.
@@ -89,15 +211,43 @@ val teenagers = sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"
teenagers.collect().foreach(println)
{% endhighlight %}
+
+
+
+
+{% highlight java %}
+
+JavaSchemaRDD schemaPeople = ... // The JavaSchemaRDD from the previous example.
+
+// JavaSchemaRDDs can be saved as parquet files, maintaining the schema information.
+schemaPeople.saveAsParquetFile("people.parquet");
+
+// Read in the parquet file created above. Parquet files are self-describing so the schema is preserved.
+// The result of loading a parquet file is also a JavaSchemaRDD.
+JavaSchemaRDD parquetFile = sqlCtx.parquetFile("people.parquet");
+
+//Parquet files can also be registered as tables and then used in SQL statements.
+parquetFile.registerAsTable("parquetFile");
+JavaSchemaRDD teenagers = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19");
+
+
+{% endhighlight %}
+
+
+
+
+
## Writing Language-Integrated Relational Queries
+**Language-Integrated queries are currently only supported in Scala.**
+
Spark SQL also supports a domain specific language for writing queries. Once again,
using the data from the above examples:
{% highlight scala %}
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
import sqlContext._
-val people: RDD[Person] // An RDD of case class objects, from the first example.
+val people: RDD[Person] = ... // An RDD of case class objects, from the first example.
// The following is the same as 'SELECT name FROM people WHERE age >= 10 AND age <= 19'
val teenagers = people.where('age >= 10).where('age <= 19).select('name)
@@ -114,14 +264,17 @@ evaluated by the SQL execution engine. A full list of the functions supported c
Spark SQL also supports reading and writing data stored in [Apache Hive](http://hive.apache.org/).
However, since Hive has a large number of dependencies, it is not included in the default Spark assembly.
-In order to use Hive you must first run '`sbt/sbt hive/assembly`'. This command builds a new assembly
-jar that includes Hive. When this jar is present, Spark will use the Hive
-assembly instead of the normal Spark assembly. Note that this Hive assembly jar must also be present
+In order to use Hive you must first run '`SPARK_HIVE=true sbt/sbt assembly/assembly`' (or use `-Phive` for maven).
+This command builds a new assembly jar that includes Hive. Note that this Hive assembly jar must also be present
on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries
(SerDes) in order to acccess data stored in Hive.
Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`.
+
+
+
+
When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and
adds support for finding tables in in the MetaStore and writing queries using HiveQL. Users who do
not have an existing Hive deployment can also experiment with the `LocalHiveContext`,
@@ -135,9 +288,34 @@ val hiveContext = new org.apache.spark.sql.hive.HiveContext(sc)
// Importing the SQL context gives access to all the public SQL functions and implicit conversions.
import hiveContext._
-sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
-sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src")
+hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
+hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src")
// Queries are expressed in HiveQL
-sql("SELECT key, value FROM src").collect().foreach(println)
-{% endhighlight %}
\ No newline at end of file
+hql("FROM src SELECT key, value").collect().foreach(println)
+{% endhighlight %}
+
+
+
+
+
+When working with Hive one must construct a `JavaHiveContext`, which inherits from `JavaSQLContext`, and
+adds support for finding tables in in the MetaStore and writing queries using HiveQL. In addition to
+the `sql` method a `JavaHiveContext` also provides an `hql` methods, which allows queries to be
+expressed in HiveQL.
+
+{% highlight java %}
+JavaSparkContext ctx = ...; // An existing JavaSparkContext.
+JavaHiveContext hiveCtx = new org.apache.spark.sql.hive.api.java.HiveContext(ctx);
+
+hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)");
+hiveCtx.hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src");
+
+// Queries are expressed in HiveQL.
+Row[] results = hiveCtx.hql("FROM src SELECT key, value").collect();
+
+{% endhighlight %}
+
+
+
+
diff --git a/examples/pom.xml b/examples/pom.xml
index a5569ff5e71f3..0b6212b5d1549 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -110,7 +110,7 @@
org.apache.hbase
hbase
- 0.94.6
+ ${hbase.version}
asm
diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
new file mode 100644
index 0000000000000..e8e63d2745692
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.sql;
+
+import java.io.Serializable;
+import java.util.List;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.VoidFunction;
+
+import org.apache.spark.sql.api.java.JavaSQLContext;
+import org.apache.spark.sql.api.java.JavaSchemaRDD;
+import org.apache.spark.sql.api.java.Row;
+
+public class JavaSparkSQL {
+ public static class Person implements Serializable {
+ private String name;
+ private int age;
+
+ String getName() {
+ return name;
+ }
+
+ void setName(String name) {
+ this.name = name;
+ }
+
+ int getAge() {
+ return age;
+ }
+
+ void setAge(int age) {
+ this.age = age;
+ }
+ }
+
+ public static void main(String[] args) throws Exception {
+ JavaSparkContext ctx = new JavaSparkContext("local", "JavaSparkSQL",
+ System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaSparkSQL.class));
+ JavaSQLContext sqlCtx = new JavaSQLContext(ctx);
+
+ // Load a text file and convert each line to a Java Bean.
+ JavaRDD people = ctx.textFile("examples/src/main/resources/people.txt").map(
+ new Function() {
+ public Person call(String line) throws Exception {
+ String[] parts = line.split(",");
+
+ Person person = new Person();
+ person.setName(parts[0]);
+ person.setAge(Integer.parseInt(parts[1].trim()));
+
+ return person;
+ }
+ });
+
+ // Apply a schema to an RDD of Java Beans and register it as a table.
+ JavaSchemaRDD schemaPeople = sqlCtx.applySchema(people, Person.class);
+ schemaPeople.registerAsTable("people");
+
+ // SQL can be run over RDDs that have been registered as tables.
+ JavaSchemaRDD teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19");
+
+ // The results of SQL queries are SchemaRDDs and support all the normal RDD operations.
+ // The columns of a row in the result can be accessed by ordinal.
+ List teenagerNames = teenagers.map(new Function() {
+ public String call(Row row) {
+ return "Name: " + row.getString(0);
+ }
+ }).collect();
+
+ // JavaSchemaRDDs can be saved as parquet files, maintaining the schema information.
+ schemaPeople.saveAsParquetFile("people.parquet");
+
+ // Read in the parquet file created above. Parquet files are self-describing so the schema is preserved.
+ // The result of loading a parquet file is also a JavaSchemaRDD.
+ JavaSchemaRDD parquetFile = sqlCtx.parquetFile("people.parquet");
+
+ //Parquet files can also be registered as tables and then used in SQL statements.
+ parquetFile.registerAsTable("parquetFile");
+ JavaSchemaRDD teenagers2 = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19");
+ }
+}
diff --git a/examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java b/examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java
index 667c72f379e71..cd8879ff886e2 100644
--- a/examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java
+++ b/examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java
@@ -17,6 +17,7 @@
package org.apache.spark.mllib.examples;
+import java.util.regex.Pattern;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
@@ -24,11 +25,9 @@
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
+import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
-import java.util.Arrays;
-import java.util.regex.Pattern;
-
/**
* Logistic regression based classification using ML Lib.
*/
@@ -47,14 +46,10 @@ public LabeledPoint call(String line) {
for (int i = 0; i < tok.length; ++i) {
x[i] = Double.parseDouble(tok[i]);
}
- return new LabeledPoint(y, x);
+ return new LabeledPoint(y, Vectors.dense(x));
}
}
- public static void printWeights(double[] a) {
- System.out.println(Arrays.toString(a));
- }
-
public static void main(String[] args) {
if (args.length != 4) {
System.err.println("Usage: JavaLR ");
@@ -80,8 +75,7 @@ public static void main(String[] args) {
LogisticRegressionModel model = LogisticRegressionWithSGD.train(points.rdd(),
iterations, stepSize);
- System.out.print("Final w: ");
- printWeights(model.weights());
+ System.out.print("Final w: " + model.weights());
System.exit(0);
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala
index c8ecbb8e41a86..0095cb8425456 100644
--- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala
@@ -53,7 +53,6 @@ object LocalALS {
for (i <- 0 until M; j <- 0 until U) {
r.set(i, j, blas.ddot(ms(i), us(j)))
}
- //println("R: " + r)
blas.daxpy(-1, targetR, r)
val sumSqs = r.aggregate(Functions.plus, Functions.square)
sqrt(sumSqs / (M * U))
diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala
index 73b0e216cac98..1fdb324b89f3a 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala
@@ -61,7 +61,7 @@ object SimpleSkewedGroupByTest {
println("RESULT: " + pairs1.groupByKey(numReducers).count)
// Print how many keys each reducer got (for debugging)
- //println("RESULT: " + pairs1.groupByKey(numReducers)
+ // println("RESULT: " + pairs1.groupByKey(numReducers)
// .map{case (k,v) => (k, v.size)}
// .collectAsMap)
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala
index ce4b3c8451e00..f59ab7e7cc24a 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala
@@ -54,7 +54,6 @@ object SparkALS {
for (i <- 0 until M; j <- 0 until U) {
r.set(i, j, blas.ddot(ms(i), us(j)))
}
- //println("R: " + r)
blas.daxpy(-1, targetR, r)
val sumSqs = r.aggregate(Functions.plus, Functions.square)
sqrt(sumSqs / (M * U))
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
index cf1fc3e808c76..e698b9bf376e1 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
@@ -34,8 +34,6 @@ object SparkHdfsLR {
case class DataPoint(x: Vector, y: Double)
def parsePoint(line: String): DataPoint = {
- //val nums = line.split(' ').map(_.toDouble)
- //return DataPoint(new Vector(nums.slice(1, D+1)), nums(0))
val tok = new java.util.StringTokenizer(line, " ")
var y = tok.nextToken.toDouble
var x = new Array[Double](D)
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala
index e5a09ecec006f..d3babc3ed12c8 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala
@@ -18,8 +18,8 @@
package org.apache.spark.examples
import scala.math.random
+
import org.apache.spark._
-import SparkContext._
/** Computes an approximation to pi */
object SparkPi {
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala
new file mode 100644
index 0000000000000..53b303d658386
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples
+
+import java.util.Random
+import scala.math.exp
+import org.apache.spark.util.Vector
+import org.apache.spark._
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.scheduler.InputFormatInfo
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * Logistic regression based classification.
+ * This example uses Tachyon to persist rdds during computation.
+ */
+object SparkTachyonHdfsLR {
+ val D = 10 // Numer of dimensions
+ val rand = new Random(42)
+
+ case class DataPoint(x: Vector, y: Double)
+
+ def parsePoint(line: String): DataPoint = {
+ val tok = new java.util.StringTokenizer(line, " ")
+ var y = tok.nextToken.toDouble
+ var x = new Array[Double](D)
+ var i = 0
+ while (i < D) {
+ x(i) = tok.nextToken.toDouble; i += 1
+ }
+ DataPoint(new Vector(x), y)
+ }
+
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ System.err.println("Usage: SparkTachyonHdfsLR ")
+ System.exit(1)
+ }
+ val inputPath = args(1)
+ val conf = SparkHadoopUtil.get.newConfiguration()
+ val sc = new SparkContext(args(0), "SparkTachyonHdfsLR",
+ System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass), Map(),
+ InputFormatInfo.computePreferredLocations(
+ Seq(new InputFormatInfo(conf, classOf[org.apache.hadoop.mapred.TextInputFormat], inputPath))
+ ))
+ val lines = sc.textFile(inputPath)
+ val points = lines.map(parsePoint _).persist(StorageLevel.OFF_HEAP)
+ val ITERATIONS = args(2).toInt
+
+ // Initialize w to a random value
+ var w = Vector(D, _ => 2 * rand.nextDouble - 1)
+ println("Initial w: " + w)
+
+ for (i <- 1 to ITERATIONS) {
+ println("On iteration " + i)
+ val gradient = points.map { p =>
+ (1 / (1 + exp(-p.y * (w dot p.x))) - 1) * p.y * p.x
+ }.reduce(_ + _)
+ w -= gradient
+ }
+
+ println("Final w: " + w)
+ System.exit(0)
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala
new file mode 100644
index 0000000000000..ce78f0876ed7c
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples
+
+import scala.math.random
+
+import org.apache.spark._
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * Computes an approximation to pi
+ * This example uses Tachyon to persist rdds during computation.
+ */
+object SparkTachyonPi {
+ def main(args: Array[String]) {
+ if (args.length == 0) {
+ System.err.println("Usage: SparkTachyonPi []")
+ System.exit(1)
+ }
+ val spark = new SparkContext(args(0), "SparkTachyonPi",
+ System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass))
+
+ val slices = if (args.length > 1) args(1).toInt else 2
+ val n = 100000 * slices
+
+ val rdd = spark.parallelize(1 to n, slices)
+ rdd.persist(StorageLevel.OFF_HEAP)
+ val count = rdd.map { i =>
+ val x = random * 2 - 1
+ val y = random * 2 - 1
+ if (x * x + y * y < 1) 1 else 0
+ }.reduce(_ + _)
+ println("Pi is roughly " + 4.0 * count / n)
+
+ spark.stop()
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/sql/examples/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/sql/examples/HiveFromSpark.scala
index abcc1f04d4279..62329bde84481 100644
--- a/examples/src/main/scala/org/apache/spark/sql/examples/HiveFromSpark.scala
+++ b/examples/src/main/scala/org/apache/spark/sql/examples/HiveFromSpark.scala
@@ -33,20 +33,20 @@ object HiveFromSpark {
val hiveContext = new LocalHiveContext(sc)
import hiveContext._
- sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
- sql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src")
+ hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
+ hql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src")
// Queries are expressed in HiveQL
println("Result of 'SELECT *': ")
- sql("SELECT * FROM src").collect.foreach(println)
+ hql("SELECT * FROM src").collect.foreach(println)
// Aggregation queries are also supported.
- val count = sql("SELECT COUNT(*) FROM src").collect().head.getInt(0)
+ val count = hql("SELECT COUNT(*) FROM src").collect().head.getInt(0)
println(s"COUNT(*): $count")
// The results of SQL queries are themselves RDDs and support all normal RDD functions. The
// items in the RDD are of type Row, which allows you to access each column by ordinal.
- val rddFromSql = sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key")
+ val rddFromSql = hql("SELECT key, value FROM src WHERE key < 10 ORDER BY key")
println("Result of RDD.map:")
val rddAsStrings = rddFromSql.map {
@@ -59,6 +59,6 @@ object HiveFromSpark {
// Queries can then join RDD data with data stored in Hive.
println("Result of SELECT *:")
- sql("SELECT * FROM records r JOIN src s ON r.key = s.key").collect().foreach(println)
+ hql("SELECT * FROM records r JOIN src s ON r.key = s.key").collect().foreach(println)
}
}
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala
index 62d3a52615584..a22e64ca3ce45 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala
@@ -168,7 +168,7 @@ object ActorWordCount {
Props(new SampleActorReceiver[String]("akka.tcp://test@%s:%s/user/FeederActor".format(
host, port.toInt))), "SampleReceiver")
- //compute wordcount
+ // compute wordcount
lines.flatMap(_.split("\\s+")).map(x => (x, 1)).reduceByKey(_ + _).print()
ssc.start()
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala
index 35be7ffa1e872..35f8f885f8f0e 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala
@@ -88,7 +88,7 @@ object ZeroMQWordCount {
def bytesToStringIterator(x: Seq[ByteString]) = (x.map(_.utf8String)).iterator
- //For this stream, a zeroMQ publisher should be running.
+ // For this stream, a zeroMQ publisher should be running.
val lines = ZeroMQUtils.createStream(ssc, url, Subscribe(topic), bytesToStringIterator _)
val words = lines.flatMap(_.split(" "))
val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala
index 0ac46c31c24c8..251f65fe4df9c 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala
@@ -21,7 +21,7 @@ import java.net.ServerSocket
import java.io.PrintWriter
import util.Random
-/** Represents a page view on a website with associated dimension data.*/
+/** Represents a page view on a website with associated dimension data. */
class PageView(val url : String, val status : Int, val zipCode : Int, val userID : Int)
extends Serializable {
override def toString() : String = {
diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
index ce3ef47cfe4bc..34012b846e21e 100644
--- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
+++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
@@ -127,7 +127,7 @@ class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol {
}
/** A NetworkReceiver which listens for events using the
- * Flume Avro interface.*/
+ * Flume Avro interface. */
private[streaming]
class FlumeReceiver(
host: String,
diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala
index 6acba25f44c0a..a538c38dc4d6f 100644
--- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala
+++ b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala
@@ -44,7 +44,7 @@ private[streaming] class ZeroMQReceiver[T: ClassTag](publisherUrl: String,
case m: ZMQMessage =>
logDebug("Received message for:" + m.frame(0))
- //We ignore first frame for processing as it is the topic
+ // We ignore first frame for processing as it is the topic
val bytes = m.frames.tail
pushBlock(bytesToObjects(bytes))
diff --git a/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala b/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala
index cd37317da77de..d03d7774e8c80 100644
--- a/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala
+++ b/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.TimeUnit
import com.codahale.metrics.MetricRegistry
import com.codahale.metrics.ganglia.GangliaReporter
import info.ganglia.gmetric4j.gmetric.GMetric
+import info.ganglia.gmetric4j.gmetric.GMetric.UDPAddressingMode
import org.apache.spark.SecurityManager
import org.apache.spark.metrics.MetricsSystem
@@ -33,10 +34,10 @@ class GangliaSink(val property: Properties, val registry: MetricRegistry,
val GANGLIA_DEFAULT_PERIOD = 10
val GANGLIA_KEY_UNIT = "unit"
- val GANGLIA_DEFAULT_UNIT = TimeUnit.SECONDS
+ val GANGLIA_DEFAULT_UNIT: TimeUnit = TimeUnit.SECONDS
val GANGLIA_KEY_MODE = "mode"
- val GANGLIA_DEFAULT_MODE = GMetric.UDPAddressingMode.MULTICAST
+ val GANGLIA_DEFAULT_MODE: UDPAddressingMode = GMetric.UDPAddressingMode.MULTICAST
// TTL for multicast messages. If listeners are X hops away in network, must be at least X.
val GANGLIA_KEY_TTL = "ttl"
@@ -45,7 +46,7 @@ class GangliaSink(val property: Properties, val registry: MetricRegistry,
val GANGLIA_KEY_HOST = "host"
val GANGLIA_KEY_PORT = "port"
- def propertyToOption(prop: String) = Option(property.getProperty(prop))
+ def propertyToOption(prop: String): Option[String] = Option(property.getProperty(prop))
if (!propertyToOption(GANGLIA_KEY_HOST).isDefined) {
throw new Exception("Ganglia sink requires 'host' property.")
@@ -58,11 +59,12 @@ class GangliaSink(val property: Properties, val registry: MetricRegistry,
val host = propertyToOption(GANGLIA_KEY_HOST).get
val port = propertyToOption(GANGLIA_KEY_PORT).get.toInt
val ttl = propertyToOption(GANGLIA_KEY_TTL).map(_.toInt).getOrElse(GANGLIA_DEFAULT_TTL)
- val mode = propertyToOption(GANGLIA_KEY_MODE)
+ val mode: UDPAddressingMode = propertyToOption(GANGLIA_KEY_MODE)
.map(u => GMetric.UDPAddressingMode.valueOf(u.toUpperCase)).getOrElse(GANGLIA_DEFAULT_MODE)
val pollPeriod = propertyToOption(GANGLIA_KEY_PERIOD).map(_.toInt)
.getOrElse(GANGLIA_DEFAULT_PERIOD)
- val pollUnit = propertyToOption(GANGLIA_KEY_UNIT).map(u => TimeUnit.valueOf(u.toUpperCase))
+ val pollUnit: TimeUnit = propertyToOption(GANGLIA_KEY_UNIT)
+ .map(u => TimeUnit.valueOf(u.toUpperCase))
.getOrElse(GANGLIA_DEFAULT_UNIT)
MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod)
diff --git a/graphx/pom.xml b/graphx/pom.xml
index 5a5022916d234..b4c67ddcd8ca9 100644
--- a/graphx/pom.xml
+++ b/graphx/pom.xml
@@ -54,7 +54,7 @@
org.jblas
jblas
- 1.2.3
+ ${jblas.version}
org.eclipse.jetty
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
index f2296a865e1b3..6d04bf790e3a5 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
@@ -45,7 +45,8 @@ class EdgeRDD[@specialized ED: ClassTag](
partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD)))
override def compute(part: Partition, context: TaskContext): Iterator[Edge[ED]] = {
- firstParent[(PartitionID, EdgePartition[ED])].iterator(part, context).next._2.iterator
+ val p = firstParent[(PartitionID, EdgePartition[ED])].iterator(part, context)
+ p.next._2.iterator.map(_.copy())
}
override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect()
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala
index fea43c3b2bbf1..dfc6a801587d2 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala
@@ -27,12 +27,12 @@ class EdgeTriplet[VD, ED] extends Edge[ED] {
/**
* The source vertex attribute
*/
- var srcAttr: VD = _ //nullValue[VD]
+ var srcAttr: VD = _ // nullValue[VD]
/**
* The destination vertex attribute
*/
- var dstAttr: VD = _ //nullValue[VD]
+ var dstAttr: VD = _ // nullValue[VD]
/**
* Set the edge properties of this triplet.
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
index 65a1a8c68f6d2..ef05623d7a0a1 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
@@ -419,5 +419,6 @@ object Graph {
* All the convenience operations are defined in the [[GraphOps]] class which may be
* shared across multiple graph implementations.
*/
- implicit def graphToGraphOps[VD: ClassTag, ED: ClassTag](g: Graph[VD, ED]) = g.ops
+ implicit def graphToGraphOps[VD: ClassTag, ED: ClassTag]
+ (g: Graph[VD, ED]): GraphOps[VD, ED] = g.ops
} // end of Graph object
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
index 57fa5eefd5e09..2e05f5d4e4969 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
@@ -56,6 +56,9 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double)
* Construct a new edge partition by applying the function f to all
* edges in this partition.
*
+ * Be careful not to keep references to the objects passed to `f`.
+ * To improve GC performance the same object is re-used for each call.
+ *
* @param f a function from an edge to a new attribute
* @tparam ED2 the type of the new attribute
* @return a new edge partition with the result of the function `f`
@@ -84,12 +87,12 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double)
* order of the edges returned by `EdgePartition.iterator` and
* should return attributes equal to the number of edges.
*
- * @param f a function from an edge to a new attribute
+ * @param iter an iterator for the new attribute values
* @tparam ED2 the type of the new attribute
- * @return a new edge partition with the result of the function `f`
- * applied to each edge
+ * @return a new edge partition with the attribute values replaced
*/
def map[ED2: ClassTag](iter: Iterator[ED2]): EdgePartition[ED2] = {
+ // Faster than iter.toArray, because the expected size is known.
val newData = new Array[ED2](data.size)
var i = 0
while (iter.hasNext) {
@@ -188,6 +191,9 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double)
/**
* Get an iterator over the edges in this partition.
*
+ * Be careful not to keep references to the objects from this iterator.
+ * To improve GC performance the same object is re-used in `next()`.
+ *
* @return an iterator over edges in the partition
*/
def iterator = new Iterator[Edge[ED]] {
@@ -216,6 +222,9 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double)
/**
* Get an iterator over the cluster of edges in this partition with source vertex id `srcId`. The
* cluster must start at position `index`.
+ *
+ * Be careful not to keep references to the objects from this iterator. To improve GC performance
+ * the same object is re-used in `next()`.
*/
private def clusterIterator(srcId: VertexId, index: Int) = new Iterator[Edge[ED]] {
private[this] val edge = new Edge[ED]
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala
index 886c250d7cffd..220a89d73d711 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala
@@ -37,20 +37,15 @@ class EdgeTripletIterator[VD: ClassTag, ED: ClassTag](
// Current position in the array.
private var pos = 0
- // A triplet object that this iterator.next() call returns. We reuse this object to avoid
- // allocating too many temporary Java objects.
- private val triplet = new EdgeTriplet[VD, ED]
-
private val vmap = new PrimitiveKeyOpenHashMap[VertexId, VD](vidToIndex, vertexArray)
override def hasNext: Boolean = pos < edgePartition.size
override def next() = {
+ val triplet = new EdgeTriplet[VD, ED]
triplet.srcId = edgePartition.srcIds(pos)
- // assert(vmap.containsKey(e.src.id))
triplet.srcAttr = vmap(triplet.srcId)
triplet.dstId = edgePartition.dstIds(pos)
- // assert(vmap.containsKey(e.dst.id))
triplet.dstAttr = vmap(triplet.dstId)
triplet.attr = edgePartition.data(pos)
pos += 1
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
index 5e9be18990ba3..c2b510a31ee3f 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
@@ -190,14 +190,14 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
new GraphImpl(vertices, newETable, routingTable, replicatedVertexView)
}
- //////////////////////////////////////////////////////////////////////////////////////////////////
+ // ///////////////////////////////////////////////////////////////////////////////////////////////
// Lower level transformation methods
- //////////////////////////////////////////////////////////////////////////////////////////////////
+ // ///////////////////////////////////////////////////////////////////////////////////////////////
override def mapReduceTriplets[A: ClassTag](
mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
reduceFunc: (A, A) => A,
- activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None) = {
+ activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None): VertexRDD[A] = {
ClosureCleaner.clean(mapFunc)
ClosureCleaner.clean(reduceFunc)
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
index fe6fe76defdc5..9d4f3750cb8e4 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
@@ -45,7 +45,7 @@ class VertexBroadcastMsg[@specialized(Int, Long, Double, Boolean) T](
* @param data value to send
*/
private[graphx]
-class MessageToPartition[@specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T](
+class MessageToPartition[@specialized(Int, Long, Double, Char, Boolean/* , AnyRef */) T](
@transient var partition: PartitionID,
var data: T)
extends Product2[PartitionID, T] with Serializable {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
index 34a145e01818f..2f2c524df6394 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
@@ -298,7 +298,6 @@ abstract class ShuffleSerializationStream(s: OutputStream) extends Serialization
s.write(v.toInt)
}
- //def writeDouble(v: Double): Unit = writeUnsignedVarLong(java.lang.Double.doubleToLongBits(v))
def writeDouble(v: Double): Unit = writeLong(java.lang.Double.doubleToLongBits(v))
override def flush(): Unit = s.flush()
@@ -391,7 +390,6 @@ abstract class ShuffleDeserializationStream(s: InputStream) extends Deserializat
(s.read() & 0xFF)
}
- //def readDouble(): Double = java.lang.Double.longBitsToDouble(readUnsignedVarLong())
def readDouble(): Double = java.lang.Double.longBitsToDouble(readLong())
override def close(): Unit = s.close()
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala
index 24699dfdd38b0..fa533a512d53b 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala
@@ -26,7 +26,7 @@ import org.apache.spark.graphx.PartitionStrategy._
*/
object Analytics extends Logging {
- def main(args: Array[String]) = {
+ def main(args: Array[String]): Unit = {
val host = args(0)
val taskType = args(1)
val fname = args(2)
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala
index 014a7335f85cc..087b1156f690b 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala
@@ -65,7 +65,7 @@ private[graphx] object BytecodeUtils {
val finder = new MethodInvocationFinder(c.getName, m)
getClassReader(c).accept(finder, 0)
for (classMethod <- finder.methodsInvoked) {
- //println(classMethod)
+ // println(classMethod)
if (classMethod._1 == targetClass && classMethod._2 == targetMethod) {
return true
} else if (!seen.contains(classMethod)) {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala
index f841846c0e510..a3c8de3f9068f 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala
@@ -123,7 +123,7 @@ object GraphGenerators {
* the dimensions of the adjacency matrix
*/
private def addEdge(numVertices: Int): Edge[Int] = {
- //val (src, dst) = chooseCell(numVertices/2.0, numVertices/2.0, numVertices/2.0)
+ // val (src, dst) = chooseCell(numVertices/2.0, numVertices/2.0, numVertices/2.0)
val v = math.round(numVertices.toFloat/2.0).toInt
val (src, dst) = chooseCell(v, v, v)
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala
new file mode 100644
index 0000000000000..9cbb2d2acdc2d
--- /dev/null
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.graphx.impl
+
+import scala.reflect.ClassTag
+import scala.util.Random
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.graphx._
+
+class EdgeTripletIteratorSuite extends FunSuite {
+ test("iterator.toList") {
+ val builder = new EdgePartitionBuilder[Int]
+ builder.add(1, 2, 0)
+ builder.add(1, 3, 0)
+ builder.add(1, 4, 0)
+ val vidmap = new VertexIdToIndexMap
+ vidmap.add(1)
+ vidmap.add(2)
+ vidmap.add(3)
+ vidmap.add(4)
+ val vs = Array.fill(vidmap.capacity)(0)
+ val iter = new EdgeTripletIterator[Int, Int](vidmap, vs, builder.toEdgePartition)
+ val result = iter.toList.map(et => (et.srcId, et.dstId))
+ assert(result === Seq((1, 2), (1, 3), (1, 4)))
+ }
+}
diff --git a/make-distribution.sh b/make-distribution.sh
index 6bc6819d8da92..5c780fcbda863 100755
--- a/make-distribution.sh
+++ b/make-distribution.sh
@@ -128,7 +128,7 @@ if [ "$SPARK_TACHYON" == "true" ]; then
TACHYON_VERSION="0.4.1"
TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/tachyon-${TACHYON_VERSION}-bin.tar.gz"
- TMPD=`mktemp -d`
+ TMPD=`mktemp -d 2>/dev/null || mktemp -d -t 'disttmp'`
pushd $TMPD > /dev/null
echo "Fetchting tachyon tgz"
@@ -139,7 +139,13 @@ if [ "$SPARK_TACHYON" == "true" ]; then
mkdir -p "$DISTDIR/tachyon/src/main/java/tachyon/web"
cp -r "tachyon-${TACHYON_VERSION}"/{bin,conf,libexec} "$DISTDIR/tachyon"
cp -r "tachyon-${TACHYON_VERSION}"/src/main/java/tachyon/web/resources "$DISTDIR/tachyon/src/main/java/tachyon/web"
- sed -i "s|export TACHYON_JAR=\$TACHYON_HOME/target/\(.*\)|# This is set for spark's make-distribution\n export TACHYON_JAR=\$TACHYON_HOME/../../jars/\1|" "$DISTDIR/tachyon/libexec/tachyon-config.sh"
+
+ if [[ `uname -a` == Darwin* ]]; then
+ # need to run sed differently on osx
+ nl=$'\n'; sed -i "" -e "s|export TACHYON_JAR=\$TACHYON_HOME/target/\(.*\)|# This is set for spark's make-distribution\\$nl export TACHYON_JAR=\$TACHYON_HOME/../jars/\1|" "$DISTDIR/tachyon/libexec/tachyon-config.sh"
+ else
+ sed -i "s|export TACHYON_JAR=\$TACHYON_HOME/target/\(.*\)|# This is set for spark's make-distribution\n export TACHYON_JAR=\$TACHYON_HOME/../jars/\1|" "$DISTDIR/tachyon/libexec/tachyon-config.sh"
+ fi
popd > /dev/null
rm -rf $TMPD
diff --git a/mllib/pom.xml b/mllib/pom.xml
index fec1cc94b2642..e7ce00efc4af6 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -58,7 +58,7 @@
org.jblas
jblas
- 1.2.3
+ ${jblas.version}
org.scalanlp
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 3449c698da60b..2df5b0d02b699 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -110,16 +110,16 @@ class PythonMLLibAPI extends Serializable {
private def trainRegressionModel(
trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel,
- dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]):
- java.util.LinkedList[java.lang.Object] = {
+ dataBytesJRDD: JavaRDD[Array[Byte]],
+ initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = {
val data = dataBytesJRDD.rdd.map(xBytes => {
val x = deserializeDoubleVector(xBytes)
- LabeledPoint(x(0), x.slice(1, x.length))
+ LabeledPoint(x(0), Vectors.dense(x.slice(1, x.length)))
})
val initialWeights = deserializeDoubleVector(initialWeightsBA)
val model = trainFunc(data, initialWeights)
val ret = new java.util.LinkedList[java.lang.Object]()
- ret.add(serializeDoubleVector(model.weights))
+ ret.add(serializeDoubleVector(model.weights.toArray))
ret.add(model.intercept: java.lang.Double)
ret
}
@@ -127,75 +127,127 @@ class PythonMLLibAPI extends Serializable {
/**
* Java stub for Python mllib LinearRegressionWithSGD.train()
*/
- def trainLinearRegressionModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]],
- numIterations: Int, stepSize: Double, miniBatchFraction: Double,
+ def trainLinearRegressionModelWithSGD(
+ dataBytesJRDD: JavaRDD[Array[Byte]],
+ numIterations: Int,
+ stepSize: Double,
+ miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
- trainRegressionModel((data, initialWeights) =>
- LinearRegressionWithSGD.train(data, numIterations, stepSize,
- miniBatchFraction, initialWeights),
- dataBytesJRDD, initialWeightsBA)
+ trainRegressionModel(
+ (data, initialWeights) =>
+ LinearRegressionWithSGD.train(
+ data,
+ numIterations,
+ stepSize,
+ miniBatchFraction,
+ Vectors.dense(initialWeights)),
+ dataBytesJRDD,
+ initialWeightsBA)
}
/**
* Java stub for Python mllib LassoWithSGD.train()
*/
- def trainLassoModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
- stepSize: Double, regParam: Double, miniBatchFraction: Double,
+ def trainLassoModelWithSGD(
+ dataBytesJRDD: JavaRDD[Array[Byte]],
+ numIterations: Int,
+ stepSize: Double,
+ regParam: Double,
+ miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
- trainRegressionModel((data, initialWeights) =>
- LassoWithSGD.train(data, numIterations, stepSize, regParam,
- miniBatchFraction, initialWeights),
- dataBytesJRDD, initialWeightsBA)
+ trainRegressionModel(
+ (data, initialWeights) =>
+ LassoWithSGD.train(
+ data,
+ numIterations,
+ stepSize,
+ regParam,
+ miniBatchFraction,
+ Vectors.dense(initialWeights)),
+ dataBytesJRDD,
+ initialWeightsBA)
}
/**
* Java stub for Python mllib RidgeRegressionWithSGD.train()
*/
- def trainRidgeModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
- stepSize: Double, regParam: Double, miniBatchFraction: Double,
+ def trainRidgeModelWithSGD(
+ dataBytesJRDD: JavaRDD[Array[Byte]],
+ numIterations: Int,
+ stepSize: Double,
+ regParam: Double,
+ miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
- trainRegressionModel((data, initialWeights) =>
- RidgeRegressionWithSGD.train(data, numIterations, stepSize, regParam,
- miniBatchFraction, initialWeights),
- dataBytesJRDD, initialWeightsBA)
+ trainRegressionModel(
+ (data, initialWeights) =>
+ RidgeRegressionWithSGD.train(
+ data,
+ numIterations,
+ stepSize,
+ regParam,
+ miniBatchFraction,
+ Vectors.dense(initialWeights)),
+ dataBytesJRDD,
+ initialWeightsBA)
}
/**
* Java stub for Python mllib SVMWithSGD.train()
*/
- def trainSVMModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
- stepSize: Double, regParam: Double, miniBatchFraction: Double,
+ def trainSVMModelWithSGD(
+ dataBytesJRDD: JavaRDD[Array[Byte]],
+ numIterations: Int,
+ stepSize: Double,
+ regParam: Double,
+ miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
- trainRegressionModel((data, initialWeights) =>
- SVMWithSGD.train(data, numIterations, stepSize, regParam,
- miniBatchFraction, initialWeights),
- dataBytesJRDD, initialWeightsBA)
+ trainRegressionModel(
+ (data, initialWeights) =>
+ SVMWithSGD.train(
+ data,
+ numIterations,
+ stepSize,
+ regParam,
+ miniBatchFraction,
+ Vectors.dense(initialWeights)),
+ dataBytesJRDD,
+ initialWeightsBA)
}
/**
* Java stub for Python mllib LogisticRegressionWithSGD.train()
*/
- def trainLogisticRegressionModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]],
- numIterations: Int, stepSize: Double, miniBatchFraction: Double,
+ def trainLogisticRegressionModelWithSGD(
+ dataBytesJRDD: JavaRDD[Array[Byte]],
+ numIterations: Int,
+ stepSize: Double,
+ miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
- trainRegressionModel((data, initialWeights) =>
- LogisticRegressionWithSGD.train(data, numIterations, stepSize,
- miniBatchFraction, initialWeights),
- dataBytesJRDD, initialWeightsBA)
+ trainRegressionModel(
+ (data, initialWeights) =>
+ LogisticRegressionWithSGD.train(
+ data,
+ numIterations,
+ stepSize,
+ miniBatchFraction,
+ Vectors.dense(initialWeights)),
+ dataBytesJRDD,
+ initialWeightsBA)
}
/**
* Java stub for NaiveBayes.train()
*/
- def trainNaiveBayes(dataBytesJRDD: JavaRDD[Array[Byte]], lambda: Double)
- : java.util.List[java.lang.Object] =
- {
+ def trainNaiveBayes(
+ dataBytesJRDD: JavaRDD[Array[Byte]],
+ lambda: Double): java.util.List[java.lang.Object] = {
val data = dataBytesJRDD.rdd.map(xBytes => {
val x = deserializeDoubleVector(xBytes)
- LabeledPoint(x(0), x.slice(1, x.length))
+ LabeledPoint(x(0), Vectors.dense(x.slice(1, x.length)))
})
val model = NaiveBayes.train(data, lambda)
val ret = new java.util.LinkedList[java.lang.Object]()
+ ret.add(serializeDoubleVector(model.labels))
ret.add(serializeDoubleVector(model.pi))
ret.add(serializeDoubleMatrix(model.theta))
ret
@@ -204,9 +256,12 @@ class PythonMLLibAPI extends Serializable {
/**
* Java stub for Python mllib KMeans.train()
*/
- def trainKMeansModel(dataBytesJRDD: JavaRDD[Array[Byte]], k: Int,
- maxIterations: Int, runs: Int, initializationMode: String):
- java.util.List[java.lang.Object] = {
+ def trainKMeansModel(
+ dataBytesJRDD: JavaRDD[Array[Byte]],
+ k: Int,
+ maxIterations: Int,
+ runs: Int,
+ initializationMode: String): java.util.List[java.lang.Object] = {
val data = dataBytesJRDD.rdd.map(xBytes => Vectors.dense(deserializeDoubleVector(xBytes)))
val model = KMeans.train(data, k, maxIterations, runs, initializationMode)
val ret = new java.util.LinkedList[java.lang.Object]()
@@ -259,8 +314,12 @@ class PythonMLLibAPI extends Serializable {
* needs to be taken in the Python code to ensure it gets freed on exit; see
* the Py4J documentation.
*/
- def trainALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int,
- iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = {
+ def trainALSModel(
+ ratingsBytesJRDD: JavaRDD[Array[Byte]],
+ rank: Int,
+ iterations: Int,
+ lambda: Double,
+ blocks: Int): MatrixFactorizationModel = {
val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
ALS.train(ratings, rank, iterations, lambda, blocks)
}
@@ -271,8 +330,13 @@ class PythonMLLibAPI extends Serializable {
* Extra care needs to be taken in the Python code to ensure it gets freed on
* exit; see the Py4J documentation.
*/
- def trainImplicitALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int,
- iterations: Int, lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = {
+ def trainImplicitALSModel(
+ ratingsBytesJRDD: JavaRDD[Array[Byte]],
+ rank: Int,
+ iterations: Int,
+ lambda: Double,
+ blocks: Int,
+ alpha: Double): MatrixFactorizationModel = {
val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
index 391f5b9b7a7de..bd10e2e9e10e2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
@@ -17,22 +17,27 @@
package org.apache.spark.mllib.classification
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
+/**
+ * Represents a classification model that predicts to which of a set of categories an example
+ * belongs. The categories are represented by double values: 0.0, 1.0, 2.0, etc.
+ */
trait ClassificationModel extends Serializable {
/**
* Predict values for the given data set using the model trained.
*
* @param testData RDD representing data points to be predicted
- * @return RDD[Int] where each entry contains the corresponding prediction
+ * @return an RDD[Double] where each entry contains the corresponding prediction
*/
- def predict(testData: RDD[Array[Double]]): RDD[Double]
+ def predict(testData: RDD[Vector]): RDD[Double]
/**
* Predict values for a single data point using the model trained.
*
* @param testData array representing a single data point
- * @return Int prediction from the trained model
+ * @return predicted category from the trained model
*/
- def predict(testData: Array[Double]): Double
+ def predict(testData: Vector): Double
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index a481f522761e2..798f3a5c94740 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -17,16 +17,12 @@
package org.apache.spark.mllib.classification
-import scala.math.round
-
import org.apache.spark.SparkContext
-import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.mllib.util.DataValidators
-
-import org.jblas.DoubleMatrix
+import org.apache.spark.mllib.util.{DataValidators, MLUtils}
+import org.apache.spark.rdd.RDD
/**
* Classification model trained using Logistic Regression.
@@ -35,15 +31,38 @@ import org.jblas.DoubleMatrix
* @param intercept Intercept computed for this model.
*/
class LogisticRegressionModel(
- override val weights: Array[Double],
+ override val weights: Vector,
override val intercept: Double)
- extends GeneralizedLinearModel(weights, intercept)
- with ClassificationModel with Serializable {
+ extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {
+
+ private var threshold: Option[Double] = Some(0.5)
+
+ /**
+ * Sets the threshold that separates positive predictions from negative predictions. An example
+ * with prediction score greater than or equal to this threshold is identified as an positive,
+ * and negative otherwise. The default value is 0.5.
+ */
+ def setThreshold(threshold: Double): this.type = {
+ this.threshold = Some(threshold)
+ this
+ }
- override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
+ /**
+ * Clears the threshold so that `predict` will output raw prediction scores.
+ */
+ def clearThreshold(): this.type = {
+ threshold = None
+ this
+ }
+
+ override def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
intercept: Double) = {
- val margin = dataMatrix.mmul(weightMatrix).get(0) + intercept
- round(1.0/ (1.0 + math.exp(margin * -1)))
+ val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
+ val score = 1.0/ (1.0 + math.exp(-margin))
+ threshold match {
+ case Some(t) => if (score < t) 0.0 else 1.0
+ case None => score
+ }
}
}
@@ -56,16 +75,15 @@ class LogisticRegressionWithSGD private (
var numIterations: Int,
var regParam: Double,
var miniBatchFraction: Double)
- extends GeneralizedLinearAlgorithm[LogisticRegressionModel]
- with Serializable {
+ extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable {
val gradient = new LogisticGradient()
val updater = new SimpleUpdater()
override val optimizer = new GradientDescent(gradient, updater)
- .setStepSize(stepSize)
- .setNumIterations(numIterations)
- .setRegParam(regParam)
- .setMiniBatchFraction(miniBatchFraction)
+ .setStepSize(stepSize)
+ .setNumIterations(numIterations)
+ .setRegParam(regParam)
+ .setMiniBatchFraction(miniBatchFraction)
override val validators = List(DataValidators.classificationLabels)
/**
@@ -73,7 +91,7 @@ class LogisticRegressionWithSGD private (
*/
def this() = this(1.0, 100, 0.0, 1.0)
- def createModel(weights: Array[Double], intercept: Double) = {
+ def createModel(weights: Vector, intercept: Double) = {
new LogisticRegressionModel(weights, intercept)
}
}
@@ -105,11 +123,9 @@ object LogisticRegressionWithSGD {
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double,
- initialWeights: Array[Double])
- : LogisticRegressionModel =
- {
- new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run(
- input, initialWeights)
+ initialWeights: Vector): LogisticRegressionModel = {
+ new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction)
+ .run(input, initialWeights)
}
/**
@@ -128,11 +144,9 @@ object LogisticRegressionWithSGD {
input: RDD[LabeledPoint],
numIterations: Int,
stepSize: Double,
- miniBatchFraction: Double)
- : LogisticRegressionModel =
- {
- new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run(
- input)
+ miniBatchFraction: Double): LogisticRegressionModel = {
+ new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction)
+ .run(input)
}
/**
@@ -150,9 +164,7 @@ object LogisticRegressionWithSGD {
def train(
input: RDD[LabeledPoint],
numIterations: Int,
- stepSize: Double)
- : LogisticRegressionModel =
- {
+ stepSize: Double): LogisticRegressionModel = {
train(input, numIterations, stepSize, 1.0)
}
@@ -168,9 +180,7 @@ object LogisticRegressionWithSGD {
*/
def train(
input: RDD[LabeledPoint],
- numIterations: Int)
- : LogisticRegressionModel =
- {
+ numIterations: Int): LogisticRegressionModel = {
train(input, numIterations, 1.0, 1.0)
}
@@ -183,7 +193,7 @@ object LogisticRegressionWithSGD {
val sc = new SparkContext(args(0), "LogisticRegression")
val data = MLUtils.loadLabeledData(sc, args(1))
val model = LogisticRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble)
- println("Weights: " + model.weights.mkString("[", ", ", "]"))
+ println("Weights: " + model.weights)
println("Intercept: " + model.intercept)
sc.stop()
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 6539b2f339465..e956185319a69 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -17,14 +17,14 @@
package org.apache.spark.mllib.classification
-import scala.collection.mutable
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
-import org.jblas.DoubleMatrix
-
-import org.apache.spark.{SparkContext, Logging}
+import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
/**
* Model for Naive Bayes Classifiers.
@@ -32,19 +32,28 @@ import org.apache.spark.mllib.util.MLUtils
* @param pi Log of class priors, whose dimension is C.
* @param theta Log of class conditional probabilities, whose dimension is CxD.
*/
-class NaiveBayesModel(val pi: Array[Double], val theta: Array[Array[Double]])
- extends ClassificationModel with Serializable {
-
- // Create a column vector that can be used for predictions
- private val _pi = new DoubleMatrix(pi.length, 1, pi: _*)
- private val _theta = new DoubleMatrix(theta)
+class NaiveBayesModel(
+ val labels: Array[Double],
+ val pi: Array[Double],
+ val theta: Array[Array[Double]]) extends ClassificationModel with Serializable {
+
+ private val brzPi = new BDV[Double](pi)
+ private val brzTheta = new BDM[Double](theta.length, theta(0).length)
+
+ var i = 0
+ while (i < theta.length) {
+ var j = 0
+ while (j < theta(i).length) {
+ brzTheta(i, j) = theta(i)(j)
+ j += 1
+ }
+ i += 1
+ }
- def predict(testData: RDD[Array[Double]]): RDD[Double] = testData.map(predict)
+ override def predict(testData: RDD[Vector]): RDD[Double] = testData.map(predict)
- def predict(testData: Array[Double]): Double = {
- val dataMatrix = new DoubleMatrix(testData.length, 1, testData: _*)
- val result = _pi.add(_theta.mmul(dataMatrix))
- result.argmax()
+ override def predict(testData: Vector): Double = {
+ labels(brzArgmax(brzPi + brzTheta * testData.toBreeze))
}
}
@@ -56,9 +65,8 @@ class NaiveBayesModel(val pi: Array[Double], val theta: Array[Array[Double]])
* document classification. By making every vector a 0-1 vector, it can also be used as
* Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
*/
-class NaiveBayes private (var lambda: Double)
- extends Serializable with Logging
-{
+class NaiveBayes private (var lambda: Double) extends Serializable with Logging {
+
def this() = this(1.0)
/** Set the smoothing parameter. Default: 1.0. */
@@ -70,45 +78,42 @@ class NaiveBayes private (var lambda: Double)
/**
* Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
*
- * @param data RDD of (label, array of features) pairs.
+ * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
*/
def run(data: RDD[LabeledPoint]) = {
- // Aggregates all sample points to driver side to get sample count and summed feature vector
- // for each label. The shape of `zeroCombiner` & `aggregated` is:
- //
- // label: Int -> (count: Int, featuresSum: DoubleMatrix)
- val zeroCombiner = mutable.Map.empty[Int, (Int, DoubleMatrix)]
- val aggregated = data.aggregate(zeroCombiner)({ (combiner, point) =>
- point match {
- case LabeledPoint(label, features) =>
- val (count, featuresSum) = combiner.getOrElse(label.toInt, (0, DoubleMatrix.zeros(1)))
- val fs = new DoubleMatrix(features.length, 1, features: _*)
- combiner += label.toInt -> (count + 1, featuresSum.addi(fs))
- }
- }, { (lhs, rhs) =>
- for ((label, (c, fs)) <- rhs) {
- val (count, featuresSum) = lhs.getOrElse(label, (0, DoubleMatrix.zeros(1)))
- lhs(label) = (count + c, featuresSum.addi(fs))
+ // Aggregates term frequencies per label.
+ // TODO: Calling combineByKey and collect creates two stages, we can implement something
+ // TODO: similar to reduceByKeyLocally to save one stage.
+ val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])](
+ createCombiner = (v: Vector) => (1L, v.toBreeze.toDenseVector),
+ mergeValue = (c: (Long, BDV[Double]), v: Vector) => (c._1 + 1L, c._2 += v.toBreeze),
+ mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) =>
+ (c1._1 + c2._1, c1._2 += c2._2)
+ ).collect()
+ val numLabels = aggregated.length
+ var numDocuments = 0L
+ aggregated.foreach { case (_, (n, _)) =>
+ numDocuments += n
+ }
+ val numFeatures = aggregated.head match { case (_, (_, v)) => v.size }
+ val labels = new Array[Double](numLabels)
+ val pi = new Array[Double](numLabels)
+ val theta = Array.fill(numLabels)(new Array[Double](numFeatures))
+ val piLogDenom = math.log(numDocuments + numLabels * lambda)
+ var i = 0
+ aggregated.foreach { case (label, (n, sumTermFreqs)) =>
+ labels(i) = label
+ val thetaLogDenom = math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
+ pi(i) = math.log(n + lambda) - piLogDenom
+ var j = 0
+ while (j < numFeatures) {
+ theta(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom
+ j += 1
}
- lhs
- })
-
- // Kinds of label
- val C = aggregated.size
- // Total sample count
- val N = aggregated.values.map(_._1).sum
-
- val pi = new Array[Double](C)
- val theta = new Array[Array[Double]](C)
- val piLogDenom = math.log(N + C * lambda)
-
- for ((label, (count, fs)) <- aggregated) {
- val thetaLogDenom = math.log(fs.sum() + fs.length * lambda)
- pi(label) = math.log(count + lambda) - piLogDenom
- theta(label) = fs.toArray.map(f => math.log(f + lambda) - thetaLogDenom)
+ i += 1
}
- new NaiveBayesModel(pi, theta)
+ new NaiveBayesModel(labels, pi, theta)
}
}
@@ -158,8 +163,9 @@ object NaiveBayes {
} else {
NaiveBayes.train(data, args(2).toDouble)
}
- println("Pi: " + model.pi.mkString("[", ", ", "]"))
- println("Theta:\n" + model.theta.map(_.mkString("[", ", ", "]")).mkString("[", "\n ", "]"))
+
+ println("Pi\n: " + model.pi)
+ println("Theta:\n" + model.theta)
sc.stop()
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index 6dff29dfb45cc..e31a08899f8bc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -18,13 +18,11 @@
package org.apache.spark.mllib.classification
import org.apache.spark.SparkContext
-import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.mllib.util.DataValidators
-
-import org.jblas.DoubleMatrix
+import org.apache.spark.mllib.util.{DataValidators, MLUtils}
+import org.apache.spark.rdd.RDD
/**
* Model for Support Vector Machines (SVMs).
@@ -33,15 +31,37 @@ import org.jblas.DoubleMatrix
* @param intercept Intercept computed for this model.
*/
class SVMModel(
- override val weights: Array[Double],
+ override val weights: Vector,
override val intercept: Double)
- extends GeneralizedLinearModel(weights, intercept)
- with ClassificationModel with Serializable {
+ extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {
+
+ private var threshold: Option[Double] = Some(0.0)
+
+ /**
+ * Sets the threshold that separates positive predictions from negative predictions. An example
+ * with prediction score greater than or equal to this threshold is identified as an positive,
+ * and negative otherwise. The default value is 0.0.
+ */
+ def setThreshold(threshold: Double): this.type = {
+ this.threshold = Some(threshold)
+ this
+ }
- override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
+ /**
+ * Clears the threshold so that `predict` will output raw prediction scores.
+ */
+ def clearThreshold(): this.type = {
+ threshold = None
+ this
+ }
+
+ override def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
intercept: Double) = {
- val margin = dataMatrix.dot(weightMatrix) + intercept
- if (margin < 0) 0.0 else 1.0
+ val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
+ threshold match {
+ case Some(t) => if (margin < 0) 0.0 else 1.0
+ case None => margin
+ }
}
}
@@ -71,7 +91,7 @@ class SVMWithSGD private (
*/
def this() = this(1.0, 100, 1.0, 1.0)
- def createModel(weights: Array[Double], intercept: Double) = {
+ def createModel(weights: Vector, intercept: Double) = {
new SVMModel(weights, intercept)
}
}
@@ -103,11 +123,9 @@ object SVMWithSGD {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeights: Array[Double])
- : SVMModel =
- {
- new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input,
- initialWeights)
+ initialWeights: Vector): SVMModel = {
+ new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction)
+ .run(input, initialWeights)
}
/**
@@ -127,9 +145,7 @@ object SVMWithSGD {
numIterations: Int,
stepSize: Double,
regParam: Double,
- miniBatchFraction: Double)
- : SVMModel =
- {
+ miniBatchFraction: Double): SVMModel = {
new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input)
}
@@ -149,9 +165,7 @@ object SVMWithSGD {
input: RDD[LabeledPoint],
numIterations: Int,
stepSize: Double,
- regParam: Double)
- : SVMModel =
- {
+ regParam: Double): SVMModel = {
train(input, numIterations, stepSize, regParam, 1.0)
}
@@ -165,11 +179,7 @@ object SVMWithSGD {
* @param numIterations Number of iterations of gradient descent to run.
* @return a SVMModel which has the weights and offset from training.
*/
- def train(
- input: RDD[LabeledPoint],
- numIterations: Int)
- : SVMModel =
- {
+ def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = {
train(input, numIterations, 1.0, 1.0, 1.0)
}
@@ -181,7 +191,8 @@ object SVMWithSGD {
val sc = new SparkContext(args(0), "SVM")
val data = MLUtils.loadLabeledData(sc, args(1))
val model = SVMWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
- println("Weights: " + model.weights.mkString("[", ", ", "]"))
+
+ println("Weights: " + model.weights)
println("Intercept: " + model.intercept)
sc.stop()
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index b412738e3f00a..a78503df3134d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -42,8 +42,7 @@ class KMeans private (
var runs: Int,
var initializationMode: String,
var initializationSteps: Int,
- var epsilon: Double)
- extends Serializable with Logging {
+ var epsilon: Double) extends Serializable with Logging {
def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4)
/** Set the number of clusters to create (k). Default: 2. */
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 01c1501548f87..2cea58cd3fd22 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -54,6 +54,12 @@ trait Vector extends Serializable {
* Converts the instance to a breeze vector.
*/
private[mllib] def toBreeze: BV[Double]
+
+ /**
+ * Gets the value of the ith element.
+ * @param i index
+ */
+ private[mllib] def apply(i: Int): Double = toBreeze(i)
}
/**
@@ -145,6 +151,8 @@ class DenseVector(val values: Array[Double]) extends Vector {
override def toArray: Array[Double] = values
private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values)
+
+ override def apply(i: Int) = values(i)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
index 82124703da6cd..20654284965ed 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
@@ -17,7 +17,9 @@
package org.apache.spark.mllib.optimization
-import org.jblas.DoubleMatrix
+import breeze.linalg.{axpy => brzAxpy}
+
+import org.apache.spark.mllib.linalg.{Vectors, Vector}
/**
* Class used to compute the gradient for a loss function, given a single data point.
@@ -26,17 +28,26 @@ abstract class Gradient extends Serializable {
/**
* Compute the gradient and loss given the features of a single data point.
*
- * @param data - Feature values for one data point. Column matrix of size dx1
- * where d is the number of features.
- * @param label - Label for this data item.
- * @param weights - Column matrix containing weights for every feature.
+ * @param data features for one data point
+ * @param label label for this data point
+ * @param weights weights/coefficients corresponding to features
*
- * @return A tuple of 2 elements. The first element is a column matrix containing the computed
- * gradient and the second element is the loss computed at this data point.
+ * @return (gradient: Vector, loss: Double)
+ */
+ def compute(data: Vector, label: Double, weights: Vector): (Vector, Double)
+
+ /**
+ * Compute the gradient and loss given the features of a single data point,
+ * add the gradient to a provided vector to avoid creating new objects, and return loss.
*
+ * @param data features for one data point
+ * @param label label for this data point
+ * @param weights weights/coefficients corresponding to features
+ * @param cumGradient the computed gradient will be added to this vector
+ *
+ * @return loss
*/
- def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix):
- (DoubleMatrix, Double)
+ def compute(data: Vector, label: Double, weights: Vector, cumGradient: Vector): Double
}
/**
@@ -44,12 +55,12 @@ abstract class Gradient extends Serializable {
* See also the documentation for the precise formulation.
*/
class LogisticGradient extends Gradient {
- override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix):
- (DoubleMatrix, Double) = {
- val margin: Double = -1.0 * data.dot(weights)
+ override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
+ val brzData = data.toBreeze
+ val brzWeights = weights.toBreeze
+ val margin: Double = -1.0 * brzWeights.dot(brzData)
val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
-
- val gradient = data.mul(gradientMultiplier)
+ val gradient = brzData * gradientMultiplier
val loss =
if (label > 0) {
math.log(1 + math.exp(margin))
@@ -57,7 +68,26 @@ class LogisticGradient extends Gradient {
math.log(1 + math.exp(margin)) - margin
}
- (gradient, loss)
+ (Vectors.fromBreeze(gradient), loss)
+ }
+
+ override def compute(
+ data: Vector,
+ label: Double,
+ weights: Vector,
+ cumGradient: Vector): Double = {
+ val brzData = data.toBreeze
+ val brzWeights = weights.toBreeze
+ val margin: Double = -1.0 * brzWeights.dot(brzData)
+ val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
+
+ brzAxpy(gradientMultiplier, brzData, cumGradient.toBreeze)
+
+ if (label > 0) {
+ math.log(1 + math.exp(margin))
+ } else {
+ math.log(1 + math.exp(margin)) - margin
+ }
}
}
@@ -68,14 +98,28 @@ class LogisticGradient extends Gradient {
* See also the documentation for the precise formulation.
*/
class LeastSquaresGradient extends Gradient {
- override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix):
- (DoubleMatrix, Double) = {
- val diff: Double = data.dot(weights) - label
-
+ override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
+ val brzData = data.toBreeze
+ val brzWeights = weights.toBreeze
+ val diff = brzWeights.dot(brzData) - label
val loss = diff * diff
- val gradient = data.mul(2.0 * diff)
+ val gradient = brzData * (2.0 * diff)
- (gradient, loss)
+ (Vectors.fromBreeze(gradient), loss)
+ }
+
+ override def compute(
+ data: Vector,
+ label: Double,
+ weights: Vector,
+ cumGradient: Vector): Double = {
+ val brzData = data.toBreeze
+ val brzWeights = weights.toBreeze
+ val diff = brzWeights.dot(brzData) - label
+
+ brzAxpy(2.0 * diff, brzData, cumGradient.toBreeze)
+
+ diff * diff
}
}
@@ -85,19 +129,40 @@ class LeastSquaresGradient extends Gradient {
* NOTE: This assumes that the labels are {0,1}
*/
class HingeGradient extends Gradient {
- override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix):
- (DoubleMatrix, Double) = {
+ override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
+ val brzData = data.toBreeze
+ val brzWeights = weights.toBreeze
+ val dotProduct = brzWeights.dot(brzData)
+
+ // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x)))
+ // Therefore the gradient is -(2y - 1)*x
+ val labelScaled = 2 * label - 1.0
+
+ if (1.0 > labelScaled * dotProduct) {
+ (Vectors.fromBreeze(brzData * (-labelScaled)), 1.0 - labelScaled * dotProduct)
+ } else {
+ (Vectors.dense(new Array[Double](weights.size)), 0.0)
+ }
+ }
- val dotProduct = data.dot(weights)
+ override def compute(
+ data: Vector,
+ label: Double,
+ weights: Vector,
+ cumGradient: Vector): Double = {
+ val brzData = data.toBreeze
+ val brzWeights = weights.toBreeze
+ val dotProduct = brzWeights.dot(brzData)
// Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x)))
// Therefore the gradient is -(2y - 1)*x
val labelScaled = 2 * label - 1.0
if (1.0 > labelScaled * dotProduct) {
- (data.mul(-labelScaled), 1.0 - labelScaled * dotProduct)
+ brzAxpy(-labelScaled, brzData, cumGradient.toBreeze)
+ 1.0 - labelScaled * dotProduct
} else {
- (DoubleMatrix.zeros(1, weights.length), 0.0)
+ 0.0
}
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index b967b22e818d3..d0777ffd63ff8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -17,12 +17,13 @@
package org.apache.spark.mllib.optimization
-import org.apache.spark.Logging
-import org.apache.spark.rdd.RDD
+import scala.collection.mutable.ArrayBuffer
-import org.jblas.DoubleMatrix
+import breeze.linalg.{Vector => BV, DenseVector => BDV}
-import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.linalg.{Vectors, Vector}
/**
* Class used to solve an optimization problem using Gradient Descent.
@@ -91,18 +92,16 @@ class GradientDescent(var gradient: Gradient, var updater: Updater)
this
}
- def optimize(data: RDD[(Double, Array[Double])], initialWeights: Array[Double])
- : Array[Double] = {
-
- val (weights, stochasticLossHistory) = GradientDescent.runMiniBatchSGD(
- data,
- gradient,
- updater,
- stepSize,
- numIterations,
- regParam,
- miniBatchFraction,
- initialWeights)
+ def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = {
+ val (weights, _) = GradientDescent.runMiniBatchSGD(
+ data,
+ gradient,
+ updater,
+ stepSize,
+ numIterations,
+ regParam,
+ miniBatchFraction,
+ initialWeights)
weights
}
@@ -133,14 +132,14 @@ object GradientDescent extends Logging {
* stochastic loss computed for every iteration.
*/
def runMiniBatchSGD(
- data: RDD[(Double, Array[Double])],
+ data: RDD[(Double, Vector)],
gradient: Gradient,
updater: Updater,
stepSize: Double,
numIterations: Int,
regParam: Double,
miniBatchFraction: Double,
- initialWeights: Array[Double]) : (Array[Double], Array[Double]) = {
+ initialWeights: Vector): (Vector, Array[Double]) = {
val stochasticLossHistory = new ArrayBuffer[Double](numIterations)
@@ -148,24 +147,27 @@ object GradientDescent extends Logging {
val miniBatchSize = nexamples * miniBatchFraction
// Initialize weights as a column vector
- var weights = new DoubleMatrix(initialWeights.length, 1, initialWeights:_*)
+ var weights = Vectors.dense(initialWeights.toArray)
/**
* For the first iteration, the regVal will be initialized as sum of sqrt of
* weights if it's L2 update; for L1 update; the same logic is followed.
*/
var regVal = updater.compute(
- weights, new DoubleMatrix(initialWeights.length, 1), 0, 1, regParam)._2
+ weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2
for (i <- 1 to numIterations) {
// Sample a subset (fraction miniBatchFraction) of the total data
// compute and sum up the subgradients on this subset (this is one map-reduce)
- val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i).map {
- case (y, features) =>
- val featuresCol = new DoubleMatrix(features.length, 1, features:_*)
- val (grad, loss) = gradient.compute(featuresCol, y, weights)
- (grad, loss)
- }.reduce((a, b) => (a._1.addi(b._1), a._2 + b._2))
+ val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i)
+ .aggregate((BDV.zeros[Double](weights.size), 0.0))(
+ seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
+ val l = gradient.compute(features, label, weights, Vectors.fromBreeze(grad))
+ (grad, loss + l)
+ },
+ combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
+ (grad1 += grad2, loss1 + loss2)
+ })
/**
* NOTE(Xinghao): lossSum is computed using the weights from the previous iteration
@@ -173,7 +175,7 @@ object GradientDescent extends Logging {
*/
stochasticLossHistory.append(lossSum / miniBatchSize + regVal)
val update = updater.compute(
- weights, gradientSum.div(miniBatchSize), stepSize, i, regParam)
+ weights, Vectors.fromBreeze(gradientSum / miniBatchSize), stepSize, i, regParam)
weights = update._1
regVal = update._2
}
@@ -181,6 +183,6 @@ object GradientDescent extends Logging {
logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format(
stochasticLossHistory.takeRight(10).mkString(", ")))
- (weights.toArray, stochasticLossHistory.toArray)
+ (weights, stochasticLossHistory.toArray)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala
index 94d30b56f212b..f9ce908a5f3b0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala
@@ -19,11 +19,12 @@ package org.apache.spark.mllib.optimization
import org.apache.spark.rdd.RDD
-trait Optimizer {
+import org.apache.spark.mllib.linalg.Vector
+
+trait Optimizer extends Serializable {
/**
* Solve the provided convex optimization problem.
*/
- def optimize(data: RDD[(Double, Array[Double])], initialWeights: Array[Double]): Array[Double]
-
+ def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
index bf8f731459e99..3b7754cd7ac28 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
@@ -18,7 +18,10 @@
package org.apache.spark.mllib.optimization
import scala.math._
-import org.jblas.DoubleMatrix
+
+import breeze.linalg.{norm => brzNorm, axpy => brzAxpy, Vector => BV}
+
+import org.apache.spark.mllib.linalg.{Vectors, Vector}
/**
* Class used to perform steps (weight update) using Gradient Descent methods.
@@ -47,8 +50,12 @@ abstract class Updater extends Serializable {
* @return A tuple of 2 elements. The first element is a column matrix containing updated weights,
* and the second element is the regularization value computed using updated weights.
*/
- def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, stepSize: Double, iter: Int,
- regParam: Double): (DoubleMatrix, Double)
+ def compute(
+ weightsOld: Vector,
+ gradient: Vector,
+ stepSize: Double,
+ iter: Int,
+ regParam: Double): (Vector, Double)
}
/**
@@ -56,11 +63,17 @@ abstract class Updater extends Serializable {
* Uses a step-size decreasing with the square root of the number of iterations.
*/
class SimpleUpdater extends Updater {
- override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
- stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = {
+ override def compute(
+ weightsOld: Vector,
+ gradient: Vector,
+ stepSize: Double,
+ iter: Int,
+ regParam: Double): (Vector, Double) = {
val thisIterStepSize = stepSize / math.sqrt(iter)
- val step = gradient.mul(thisIterStepSize)
- (weightsOld.sub(step), 0)
+ val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector
+ brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights)
+
+ (Vectors.fromBreeze(brzWeights), 0)
}
}
@@ -83,19 +96,26 @@ class SimpleUpdater extends Updater {
* Equivalently, set weight component to signum(w) * max(0.0, abs(w) - shrinkageVal)
*/
class L1Updater extends Updater {
- override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
- stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = {
+ override def compute(
+ weightsOld: Vector,
+ gradient: Vector,
+ stepSize: Double,
+ iter: Int,
+ regParam: Double): (Vector, Double) = {
val thisIterStepSize = stepSize / math.sqrt(iter)
- val step = gradient.mul(thisIterStepSize)
// Take gradient step
- val newWeights = weightsOld.sub(step)
+ val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector
+ brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights)
// Apply proximal operator (soft thresholding)
val shrinkageVal = regParam * thisIterStepSize
- (0 until newWeights.length).foreach { i =>
- val wi = newWeights.get(i)
- newWeights.put(i, signum(wi) * max(0.0, abs(wi) - shrinkageVal))
+ var i = 0
+ while (i < brzWeights.length) {
+ val wi = brzWeights(i)
+ brzWeights(i) = signum(wi) * max(0.0, abs(wi) - shrinkageVal)
+ i += 1
}
- (newWeights, newWeights.norm1 * regParam)
+
+ (Vectors.fromBreeze(brzWeights), brzNorm(brzWeights, 1.0) * regParam)
}
}
@@ -105,16 +125,23 @@ class L1Updater extends Updater {
* Uses a step-size decreasing with the square root of the number of iterations.
*/
class SquaredL2Updater extends Updater {
- override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
- stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = {
- val thisIterStepSize = stepSize / math.sqrt(iter)
- val step = gradient.mul(thisIterStepSize)
+ override def compute(
+ weightsOld: Vector,
+ gradient: Vector,
+ stepSize: Double,
+ iter: Int,
+ regParam: Double): (Vector, Double) = {
// add up both updates from the gradient of the loss (= step) as well as
// the gradient of the regularizer (= regParam * weightsOld)
// w' = w - thisIterStepSize * (gradient + regParam * w)
// w' = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient
- val newWeights = weightsOld.mul(1.0 - thisIterStepSize * regParam).sub(step)
- (newWeights, 0.5 * pow(newWeights.norm2, 2.0) * regParam)
+ val thisIterStepSize = stepSize / math.sqrt(iter)
+ val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector
+ brzWeights :*= (1.0 - thisIterStepSize * regParam)
+ brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights)
+ val norm = brzNorm(brzWeights, 2.0)
+
+ (Vectors.fromBreeze(brzWeights), 0.5 * regParam * norm * norm)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index b9621530efa22..80dc0f12ff84f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -17,11 +17,12 @@
package org.apache.spark.mllib.regression
+import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
+
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.optimization._
-
-import org.jblas.DoubleMatrix
+import org.apache.spark.mllib.linalg.{Vectors, Vector}
/**
* GeneralizedLinearModel (GLM) represents a model trained using
@@ -31,12 +32,9 @@ import org.jblas.DoubleMatrix
* @param weights Weights computed for every feature.
* @param intercept Intercept computed for this model.
*/
-abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept: Double)
+abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double)
extends Serializable {
- // Create a column vector that can be used for predictions
- private val weightsMatrix = new DoubleMatrix(weights.length, 1, weights:_*)
-
/**
* Predict the result given a data point and the weights learned.
*
@@ -44,8 +42,7 @@ abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept:
* @param weightMatrix Column vector containing the weights of the model
* @param intercept Intercept of the model.
*/
- def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
- intercept: Double): Double
+ protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double
/**
* Predict values for the given data set using the model trained.
@@ -53,16 +50,13 @@ abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept:
* @param testData RDD representing data points to be predicted
* @return RDD[Double] where each entry contains the corresponding prediction
*/
- def predict(testData: RDD[Array[Double]]): RDD[Double] = {
+ def predict(testData: RDD[Vector]): RDD[Double] = {
// A small optimization to avoid serializing the entire model. Only the weightsMatrix
// and intercept is needed.
- val localWeights = weightsMatrix
+ val localWeights = weights
val localIntercept = intercept
- testData.map { x =>
- val dataMatrix = new DoubleMatrix(1, x.length, x:_*)
- predictPoint(dataMatrix, localWeights, localIntercept)
- }
+ testData.map(v => predictPoint(v, localWeights, localIntercept))
}
/**
@@ -71,14 +65,13 @@ abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept:
* @param testData array representing a single data point
* @return Double prediction from the trained model
*/
- def predict(testData: Array[Double]): Double = {
- val dataMat = new DoubleMatrix(1, testData.length, testData:_*)
- predictPoint(dataMat, weightsMatrix, intercept)
+ def predict(testData: Vector): Double = {
+ predictPoint(testData, weights, intercept)
}
}
/**
- * GeneralizedLinearAlgorithm implements methods to train a Genearalized Linear Model (GLM).
+ * GeneralizedLinearAlgorithm implements methods to train a Generalized Linear Model (GLM).
* This class should be extended with an Optimizer to create a new GLM.
*/
abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
@@ -88,6 +81,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
val optimizer: Optimizer
+ /** Whether to add intercept (default: true). */
protected var addIntercept: Boolean = true
protected var validateData: Boolean = true
@@ -95,7 +89,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
/**
* Create a model given the weights and intercept
*/
- protected def createModel(weights: Array[Double], intercept: Double): M
+ protected def createModel(weights: Vector, intercept: Double): M
/**
* Set if the algorithm should add an intercept. Default true.
@@ -117,17 +111,27 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
* Run the algorithm with the configured parameters on an input
* RDD of LabeledPoint entries.
*/
- def run(input: RDD[LabeledPoint]) : M = {
- val nfeatures: Int = input.first().features.length
- val initialWeights = new Array[Double](nfeatures)
+ def run(input: RDD[LabeledPoint]): M = {
+ val numFeatures: Int = input.first().features.size
+ val initialWeights = Vectors.dense(new Array[Double](numFeatures))
run(input, initialWeights)
}
+ /** Prepends one to the input vector. */
+ private def prependOne(vector: Vector): Vector = {
+ val vector1 = vector.toBreeze match {
+ case dv: BDV[Double] => BDV.vertcat(BDV.ones[Double](1), dv)
+ case sv: BSV[Double] => BSV.vertcat(new BSV[Double](Array(0), Array(1.0), 1), sv)
+ case v: Any => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
+ }
+ Vectors.fromBreeze(vector1)
+ }
+
/**
* Run the algorithm with the configured parameters on an input RDD
* of LabeledPoint entries starting from the initial weights provided.
*/
- def run(input: RDD[LabeledPoint], initialWeights: Array[Double]) : M = {
+ def run(input: RDD[LabeledPoint], initialWeights: Vector): M = {
// Check the data properties before running the optimizer
if (validateData && !validators.forall(func => func(input))) {
@@ -136,25 +140,27 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
// Prepend an extra variable consisting of all 1.0's for the intercept.
val data = if (addIntercept) {
- input.map(labeledPoint => (labeledPoint.label, labeledPoint.features.+:(1.0)))
+ input.map(labeledPoint => (labeledPoint.label, prependOne(labeledPoint.features)))
} else {
input.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
}
val initialWeightsWithIntercept = if (addIntercept) {
- initialWeights.+:(1.0)
+ prependOne(initialWeights)
} else {
initialWeights
}
- val weights = optimizer.optimize(data, initialWeightsWithIntercept)
- val intercept = weights(0)
- val weightsScaled = weights.tail
+ val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept)
- val model = createModel(weightsScaled, intercept)
+ val intercept = if (addIntercept) weightsWithIntercept(0) else 0.0
+ val weights =
+ if (addIntercept) {
+ Vectors.dense(weightsWithIntercept.toArray.slice(1, weightsWithIntercept.size))
+ } else {
+ weightsWithIntercept
+ }
- logInfo("Final model weights " + model.weights.mkString(","))
- logInfo("Final model intercept " + model.intercept)
- model
+ createModel(weights, intercept)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
index 1a18292fe3f3b..3deab1ab785b9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
@@ -17,14 +17,16 @@
package org.apache.spark.mllib.regression
+import org.apache.spark.mllib.linalg.Vector
+
/**
* Class that represents the features and labels of a data point.
*
* @param label Label for this data point.
* @param features List of features for this data point.
*/
-case class LabeledPoint(label: Double, features: Array[Double]) {
+case class LabeledPoint(label: Double, features: Vector) {
override def toString: String = {
- "LabeledPoint(%s, %s)".format(label, features.mkString("[", ", ", "]"))
+ "LabeledPoint(%s, %s)".format(label, features)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index fb2bc9b92a51c..25920d0dc976e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -17,12 +17,11 @@
package org.apache.spark.mllib.regression
-import org.apache.spark.{Logging, SparkContext}
-import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.util.MLUtils
-
-import org.jblas.DoubleMatrix
+import org.apache.spark.rdd.RDD
/**
* Regression model trained using Lasso.
@@ -31,14 +30,16 @@ import org.jblas.DoubleMatrix
* @param intercept Intercept computed for this model.
*/
class LassoModel(
- override val weights: Array[Double],
+ override val weights: Vector,
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept)
with RegressionModel with Serializable {
- override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
- intercept: Double) = {
- dataMatrix.dot(weightMatrix) + intercept
+ override protected def predictPoint(
+ dataMatrix: Vector,
+ weightMatrix: Vector,
+ intercept: Double): Double = {
+ weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
}
}
@@ -55,8 +56,7 @@ class LassoWithSGD private (
var numIterations: Int,
var regParam: Double,
var miniBatchFraction: Double)
- extends GeneralizedLinearAlgorithm[LassoModel]
- with Serializable {
+ extends GeneralizedLinearAlgorithm[LassoModel] with Serializable {
val gradient = new LeastSquaresGradient()
val updater = new L1Updater()
@@ -66,47 +66,21 @@ class LassoWithSGD private (
.setMiniBatchFraction(miniBatchFraction)
// We don't want to penalize the intercept, so set this to false.
- setIntercept(false)
-
- var yMean = 0.0
- var xColMean: DoubleMatrix = _
- var xColSd: DoubleMatrix = _
+ super.setIntercept(false)
/**
* Construct a Lasso object with default parameters
*/
def this() = this(1.0, 100, 1.0, 1.0)
- def createModel(weights: Array[Double], intercept: Double) = {
- val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*)
- val weightsScaled = weightsMat.div(xColSd)
- val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0))
-
- new LassoModel(weightsScaled.data, interceptScaled)
+ override def setIntercept(addIntercept: Boolean): this.type = {
+ // TODO: Support adding intercept.
+ if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.")
+ this
}
- override def run(
- input: RDD[LabeledPoint],
- initialWeights: Array[Double])
- : LassoModel =
- {
- val nfeatures: Int = input.first.features.length
- val nexamples: Long = input.count()
-
- // To avoid penalizing the intercept, we center and scale the data.
- val stats = MLUtils.computeStats(input, nfeatures, nexamples)
- yMean = stats._1
- xColMean = stats._2
- xColSd = stats._3
-
- val normalizedData = input.map { point =>
- val yNormalized = point.label - yMean
- val featuresMat = new DoubleMatrix(nfeatures, 1, point.features:_*)
- val featuresNormalized = featuresMat.sub(xColMean).divi(xColSd)
- LabeledPoint(yNormalized, featuresNormalized.toArray)
- }
-
- super.run(normalizedData, initialWeights)
+ override protected def createModel(weights: Vector, intercept: Double) = {
+ new LassoModel(weights, intercept)
}
}
@@ -136,11 +110,9 @@ object LassoWithSGD {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeights: Array[Double])
- : LassoModel =
- {
- new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input,
- initialWeights)
+ initialWeights: Vector): LassoModel = {
+ new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction)
+ .run(input, initialWeights)
}
/**
@@ -160,9 +132,7 @@ object LassoWithSGD {
numIterations: Int,
stepSize: Double,
regParam: Double,
- miniBatchFraction: Double)
- : LassoModel =
- {
+ miniBatchFraction: Double): LassoModel = {
new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input)
}
@@ -182,9 +152,7 @@ object LassoWithSGD {
input: RDD[LabeledPoint],
numIterations: Int,
stepSize: Double,
- regParam: Double)
- : LassoModel =
- {
+ regParam: Double): LassoModel = {
train(input, numIterations, stepSize, regParam, 1.0)
}
@@ -200,9 +168,7 @@ object LassoWithSGD {
*/
def train(
input: RDD[LabeledPoint],
- numIterations: Int)
- : LassoModel =
- {
+ numIterations: Int): LassoModel = {
train(input, numIterations, 1.0, 1.0, 1.0)
}
@@ -214,7 +180,8 @@ object LassoWithSGD {
val sc = new SparkContext(args(0), "Lasso")
val data = MLUtils.loadLabeledData(sc, args(1))
val model = LassoWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
- println("Weights: " + model.weights.mkString("[", ", ", "]"))
+
+ println("Weights: " + model.weights)
println("Intercept: " + model.intercept)
sc.stop()
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
index 8ee40addb25d9..9ed927994e795 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -19,11 +19,10 @@ package org.apache.spark.mllib.regression
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.util.MLUtils
-import org.jblas.DoubleMatrix
-
/**
* Regression model trained using LinearRegression.
*
@@ -31,14 +30,15 @@ import org.jblas.DoubleMatrix
* @param intercept Intercept computed for this model.
*/
class LinearRegressionModel(
- override val weights: Array[Double],
- override val intercept: Double)
- extends GeneralizedLinearModel(weights, intercept)
- with RegressionModel with Serializable {
-
- override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
- intercept: Double) = {
- dataMatrix.dot(weightMatrix) + intercept
+ override val weights: Vector,
+ override val intercept: Double)
+ extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable {
+
+ override protected def predictPoint(
+ dataMatrix: Vector,
+ weightMatrix: Vector,
+ intercept: Double): Double = {
+ weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
}
}
@@ -55,8 +55,7 @@ class LinearRegressionWithSGD private (
var stepSize: Double,
var numIterations: Int,
var miniBatchFraction: Double)
- extends GeneralizedLinearAlgorithm[LinearRegressionModel]
- with Serializable {
+ extends GeneralizedLinearAlgorithm[LinearRegressionModel] with Serializable {
val gradient = new LeastSquaresGradient()
val updater = new SimpleUpdater()
@@ -69,7 +68,7 @@ class LinearRegressionWithSGD private (
*/
def this() = this(1.0, 100, 1.0)
- def createModel(weights: Array[Double], intercept: Double) = {
+ override protected def createModel(weights: Vector, intercept: Double) = {
new LinearRegressionModel(weights, intercept)
}
}
@@ -98,11 +97,9 @@ object LinearRegressionWithSGD {
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double,
- initialWeights: Array[Double])
- : LinearRegressionModel =
- {
- new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction).run(input,
- initialWeights)
+ initialWeights: Vector): LinearRegressionModel = {
+ new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction)
+ .run(input, initialWeights)
}
/**
@@ -120,9 +117,7 @@ object LinearRegressionWithSGD {
input: RDD[LabeledPoint],
numIterations: Int,
stepSize: Double,
- miniBatchFraction: Double)
- : LinearRegressionModel =
- {
+ miniBatchFraction: Double): LinearRegressionModel = {
new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction).run(input)
}
@@ -140,9 +135,7 @@ object LinearRegressionWithSGD {
def train(
input: RDD[LabeledPoint],
numIterations: Int,
- stepSize: Double)
- : LinearRegressionModel =
- {
+ stepSize: Double): LinearRegressionModel = {
train(input, numIterations, stepSize, 1.0)
}
@@ -158,9 +151,7 @@ object LinearRegressionWithSGD {
*/
def train(
input: RDD[LabeledPoint],
- numIterations: Int)
- : LinearRegressionModel =
- {
+ numIterations: Int): LinearRegressionModel = {
train(input, numIterations, 1.0, 1.0)
}
@@ -172,7 +163,7 @@ object LinearRegressionWithSGD {
val sc = new SparkContext(args(0), "LinearRegression")
val data = MLUtils.loadLabeledData(sc, args(1))
val model = LinearRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble)
- println("Weights: " + model.weights.mkString("[", ", ", "]"))
+ println("Weights: " + model.weights)
println("Intercept: " + model.intercept)
sc.stop()
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
index 423afc32d665c..5e4b8a345b1c5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
@@ -18,6 +18,7 @@
package org.apache.spark.mllib.regression
import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.linalg.Vector
trait RegressionModel extends Serializable {
/**
@@ -26,7 +27,7 @@ trait RegressionModel extends Serializable {
* @param testData RDD representing data points to be predicted
* @return RDD[Double] where each entry contains the corresponding prediction
*/
- def predict(testData: RDD[Array[Double]]): RDD[Double]
+ def predict(testData: RDD[Vector]): RDD[Double]
/**
* Predict values for a single data point using the model trained.
@@ -34,5 +35,5 @@ trait RegressionModel extends Serializable {
* @param testData array representing a single data point
* @return Double prediction from the trained model
*/
- def predict(testData: Array[Double]): Double
+ def predict(testData: Vector): Double
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index c504d3d40c773..1f17d2107f940 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -21,8 +21,7 @@ import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.util.MLUtils
-
-import org.jblas.DoubleMatrix
+import org.apache.spark.mllib.linalg.Vector
/**
* Regression model trained using RidgeRegression.
@@ -31,14 +30,16 @@ import org.jblas.DoubleMatrix
* @param intercept Intercept computed for this model.
*/
class RidgeRegressionModel(
- override val weights: Array[Double],
+ override val weights: Vector,
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept)
with RegressionModel with Serializable {
- override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
- intercept: Double) = {
- dataMatrix.dot(weightMatrix) + intercept
+ override protected def predictPoint(
+ dataMatrix: Vector,
+ weightMatrix: Vector,
+ intercept: Double): Double = {
+ weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
}
}
@@ -55,8 +56,7 @@ class RidgeRegressionWithSGD private (
var numIterations: Int,
var regParam: Double,
var miniBatchFraction: Double)
- extends GeneralizedLinearAlgorithm[RidgeRegressionModel]
- with Serializable {
+ extends GeneralizedLinearAlgorithm[RidgeRegressionModel] with Serializable {
val gradient = new LeastSquaresGradient()
val updater = new SquaredL2Updater()
@@ -67,47 +67,21 @@ class RidgeRegressionWithSGD private (
.setMiniBatchFraction(miniBatchFraction)
// We don't want to penalize the intercept in RidgeRegression, so set this to false.
- setIntercept(false)
-
- var yMean = 0.0
- var xColMean: DoubleMatrix = _
- var xColSd: DoubleMatrix = _
+ super.setIntercept(false)
/**
* Construct a RidgeRegression object with default parameters
*/
def this() = this(1.0, 100, 1.0, 1.0)
- def createModel(weights: Array[Double], intercept: Double) = {
- val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*)
- val weightsScaled = weightsMat.div(xColSd)
- val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)
-
- new RidgeRegressionModel(weightsScaled.data, interceptScaled)
+ override def setIntercept(addIntercept: Boolean): this.type = {
+ // TODO: Support adding intercept.
+ if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.")
+ this
}
- override def run(
- input: RDD[LabeledPoint],
- initialWeights: Array[Double])
- : RidgeRegressionModel =
- {
- val nfeatures: Int = input.first().features.length
- val nexamples: Long = input.count()
-
- // To avoid penalizing the intercept, we center and scale the data.
- val stats = MLUtils.computeStats(input, nfeatures, nexamples)
- yMean = stats._1
- xColMean = stats._2
- xColSd = stats._3
-
- val normalizedData = input.map { point =>
- val yNormalized = point.label - yMean
- val featuresMat = new DoubleMatrix(nfeatures, 1, point.features:_*)
- val featuresNormalized = featuresMat.sub(xColMean).divi(xColSd)
- LabeledPoint(yNormalized, featuresNormalized.toArray)
- }
-
- super.run(normalizedData, initialWeights)
+ override protected def createModel(weights: Vector, intercept: Double) = {
+ new RidgeRegressionModel(weights, intercept)
}
}
@@ -136,9 +110,7 @@ object RidgeRegressionWithSGD {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeights: Array[Double])
- : RidgeRegressionModel =
- {
+ initialWeights: Vector): RidgeRegressionModel = {
new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(
input, initialWeights)
}
@@ -159,9 +131,7 @@ object RidgeRegressionWithSGD {
numIterations: Int,
stepSize: Double,
regParam: Double,
- miniBatchFraction: Double)
- : RidgeRegressionModel =
- {
+ miniBatchFraction: Double): RidgeRegressionModel = {
new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input)
}
@@ -180,9 +150,7 @@ object RidgeRegressionWithSGD {
input: RDD[LabeledPoint],
numIterations: Int,
stepSize: Double,
- regParam: Double)
- : RidgeRegressionModel =
- {
+ regParam: Double): RidgeRegressionModel = {
train(input, numIterations, stepSize, regParam, 1.0)
}
@@ -197,23 +165,22 @@ object RidgeRegressionWithSGD {
*/
def train(
input: RDD[LabeledPoint],
- numIterations: Int)
- : RidgeRegressionModel =
- {
+ numIterations: Int): RidgeRegressionModel = {
train(input, numIterations, 1.0, 1.0, 1.0)
}
def main(args: Array[String]) {
if (args.length != 5) {
- println("Usage: RidgeRegression " +
- " ")
+ println("Usage: RidgeRegression " +
+ " ")
System.exit(1)
}
val sc = new SparkContext(args(0), "RidgeRegression")
val data = MLUtils.loadLabeledData(sc, args(1))
val model = RidgeRegressionWithSGD.train(data, args(4).toInt, args(2).toDouble,
args(3).toDouble)
- println("Weights: " + model.weights.mkString("[", ", ", "]"))
+
+ println("Weights: " + model.weights)
println("Intercept: " + model.intercept)
sc.stop()
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
new file mode 100644
index 0000000000000..dee9594a9dd79
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -0,0 +1,1151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import scala.util.control.Breaks._
+
+import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.FeatureType._
+import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
+import org.apache.spark.mllib.tree.model._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.random.XORShiftRandom
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+
+/**
+ * A class that implements a decision tree algorithm for classification and regression. It
+ * supports both continuous and categorical features.
+ * @param strategy The configuration parameters for the tree algorithm which specify the type
+ * of algorithm (classification, regression, etc.), feature type (continuous,
+ * categorical), depth of the tree, quantile calculation strategy, etc.
+ */
+class DecisionTree private(val strategy: Strategy) extends Serializable with Logging {
+
+ /**
+ * Method to train a decision tree model over an RDD
+ * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
+ * @return a DecisionTreeModel that can be used for prediction
+ */
+ def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
+
+ // Cache input RDD for speedup during multiple passes.
+ input.cache()
+ logDebug("algo = " + strategy.algo)
+
+ // Find the splits and the corresponding bins (interval between the splits) using a sample
+ // of the input data.
+ val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
+ logDebug("numSplits = " + bins(0).length)
+
+ // depth of the decision tree
+ val maxDepth = strategy.maxDepth
+ // the max number of nodes possible given the depth of the tree
+ val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1
+ // Initialize an array to hold filters applied to points for each node.
+ val filters = new Array[List[Filter]](maxNumNodes)
+ // The filter at the top node is an empty list.
+ filters(0) = List()
+ // Initialize an array to hold parent impurity calculations for each node.
+ val parentImpurities = new Array[Double](maxNumNodes)
+ // dummy value for top node (updated during first split calculation)
+ val nodes = new Array[Node](maxNumNodes)
+
+
+ /*
+ * The main idea here is to perform level-wise training of the decision tree nodes thus
+ * reducing the passes over the data from l to log2(l) where l is the total number of nodes.
+ * Each data sample is checked for validity w.r.t to each node at a given level -- i.e.,
+ * the sample is only used for the split calculation at the node if the sampled would have
+ * still survived the filters of the parent nodes.
+ */
+
+ // TODO: Convert for loop to while loop
+ breakable {
+ for (level <- 0 until maxDepth) {
+
+ logDebug("#####################################")
+ logDebug("level = " + level)
+ logDebug("#####################################")
+
+ // Find best split for all nodes at a level.
+ val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy,
+ level, filters, splits, bins)
+
+ for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
+ // Extract info for nodes at the current level.
+ extractNodeInfo(nodeSplitStats, level, index, nodes)
+ // Extract info for nodes at the next lower level.
+ extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
+ filters)
+ logDebug("final best split = " + nodeSplitStats._1)
+ }
+ require(scala.math.pow(2, level) == splitsStatsForLevel.length)
+ // Check whether all the nodes at the current level at leaves.
+ val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
+ logDebug("all leaf = " + allLeaf)
+ if (allLeaf) break // no more tree construction
+ }
+ }
+
+ // Initialize the top or root node of the tree.
+ val topNode = nodes(0)
+ // Build the full tree using the node info calculated in the level-wise best split calculations.
+ topNode.build(nodes)
+
+ new DecisionTreeModel(topNode, strategy.algo)
+ }
+
+ /**
+ * Extract the decision tree node information for the given tree level and node index
+ */
+ private def extractNodeInfo(
+ nodeSplitStats: (Split, InformationGainStats),
+ level: Int,
+ index: Int,
+ nodes: Array[Node]): Unit = {
+ val split = nodeSplitStats._1
+ val stats = nodeSplitStats._2
+ val nodeIndex = scala.math.pow(2, level).toInt - 1 + index
+ val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1)
+ val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
+ logDebug("Node = " + node)
+ nodes(nodeIndex) = node
+ }
+
+ /**
+ * Extract the decision tree node information for the children of the node
+ */
+ private def extractInfoForLowerLevels(
+ level: Int,
+ index: Int,
+ maxDepth: Int,
+ nodeSplitStats: (Split, InformationGainStats),
+ parentImpurities: Array[Double],
+ filters: Array[List[Filter]]): Unit = {
+ // 0 corresponds to the left child node and 1 corresponds to the right child node.
+ // TODO: Convert to while loop
+ for (i <- 0 to 1) {
+ // Calculate the index of the node from the node level and the index at the current level.
+ val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i
+ if (level < maxDepth - 1) {
+ val impurity = if (i == 0) {
+ nodeSplitStats._2.leftImpurity
+ } else {
+ nodeSplitStats._2.rightImpurity
+ }
+ logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity)
+ // noting the parent impurities
+ parentImpurities(nodeIndex) = impurity
+ // noting the parents filters for the child nodes
+ val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1)
+ filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2)
+ for (filter <- filters(nodeIndex)) {
+ logDebug("Filter = " + filter)
+ }
+ }
+ }
+ }
+}
+
+object DecisionTree extends Serializable with Logging {
+
+ /**
+ * Method to train a decision tree model where the instances are represented as an RDD of
+ * (label, features) pairs. The method supports binary classification and regression. For the
+ * binary classification, the label for each instance should either be 0 or 1 to denote the two
+ * classes. The parameters for the algorithm are specified using the strategy parameter.
+ *
+ * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
+ * for DecisionTree
+ * @param strategy The configuration parameters for the tree algorithm which specify the type
+ * of algorithm (classification, regression, etc.), feature type (continuous,
+ * categorical), depth of the tree, quantile calculation strategy, etc.
+ * @return a DecisionTreeModel that can be used for prediction
+ */
+ def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
+ new DecisionTree(strategy).train(input: RDD[LabeledPoint])
+ }
+
+ /**
+ * Method to train a decision tree model where the instances are represented as an RDD of
+ * (label, features) pairs. The method supports binary classification and regression. For the
+ * binary classification, the label for each instance should either be 0 or 1 to denote the two
+ * classes.
+ *
+ * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
+ * training data
+ * @param algo algorithm, classification or regression
+ * @param impurity impurity criterion used for information gain calculation
+ * @param maxDepth maxDepth maximum depth of the tree
+ * @return a DecisionTreeModel that can be used for prediction
+ */
+ def train(
+ input: RDD[LabeledPoint],
+ algo: Algo,
+ impurity: Impurity,
+ maxDepth: Int): DecisionTreeModel = {
+ val strategy = new Strategy(algo,impurity,maxDepth)
+ new DecisionTree(strategy).train(input: RDD[LabeledPoint])
+ }
+
+
+ /**
+ * Method to train a decision tree model where the instances are represented as an RDD of
+ * (label, features) pairs. The decision tree method supports binary classification and
+ * regression. For the binary classification, the label for each instance should either be 0 or
+ * 1 to denote the two classes. The method also supports categorical features inputs where the
+ * number of categories can specified using the categoricalFeaturesInfo option.
+ *
+ * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
+ * training data for DecisionTree
+ * @param algo classification or regression
+ * @param impurity criterion used for information gain calculation
+ * @param maxDepth maximum depth of the tree
+ * @param maxBins maximum number of bins used for splitting features
+ * @param quantileCalculationStrategy algorithm for calculating quantiles
+ * @param categoricalFeaturesInfo A map storing information about the categorical variables and
+ * the number of discrete values they take. For example,
+ * an entry (n -> k) implies the feature n is categorical with k
+ * categories 0, 1, 2, ... , k-1. It's important to note that
+ * features are zero-indexed.
+ * @return a DecisionTreeModel that can be used for prediction
+ */
+ def train(
+ input: RDD[LabeledPoint],
+ algo: Algo,
+ impurity: Impurity,
+ maxDepth: Int,
+ maxBins: Int,
+ quantileCalculationStrategy: QuantileStrategy,
+ categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
+ val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy,
+ categoricalFeaturesInfo)
+ new DecisionTree(strategy).train(input: RDD[LabeledPoint])
+ }
+
+ private val InvalidBinIndex = -1
+
+ /**
+ * Returns an array of optimal splits for all nodes at a given level
+ *
+ * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
+ * for DecisionTree
+ * @param parentImpurities Impurities for all parent nodes for the current level
+ * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
+ * parameters for construction the DecisionTree
+ * @param level Level of the tree
+ * @param filters Filters for all nodes at a given level
+ * @param splits possible splits for all features
+ * @param bins possible bins for all features
+ * @return array of splits with best splits for all nodes at a given level.
+ */
+ protected[tree] def findBestSplits(
+ input: RDD[LabeledPoint],
+ parentImpurities: Array[Double],
+ strategy: Strategy,
+ level: Int,
+ filters: Array[List[Filter]],
+ splits: Array[Array[Split]],
+ bins: Array[Array[Bin]]): Array[(Split, InformationGainStats)] = {
+
+ /*
+ * The high-level description for the best split optimizations are noted here.
+ *
+ * *Level-wise training*
+ * We perform bin calculations for all nodes at the given level to avoid making multiple
+ * passes over the data. Thus, for a slightly increased computation and storage cost we save
+ * several iterations over the data especially at higher levels of the decision tree.
+ *
+ * *Bin-wise computation*
+ * We use a bin-wise best split computation strategy instead of a straightforward best split
+ * computation strategy. Instead of analyzing each sample for contribution to the left/right
+ * child node impurity of every split, we first categorize each feature of a sample into a
+ * bin. Each bin is an interval between a low and high split. Since each splits, and thus bin,
+ * is ordered (read ordering for categorical variables in the findSplitsBins method),
+ * we exploit this structure to calculate aggregates for bins and then use these aggregates
+ * to calculate information gain for each split.
+ *
+ * *Aggregation over partitions*
+ * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know
+ * the number of splits in advance. Thus, we store the aggregates (at the appropriate
+ * indices) in a single array for all bins and rely upon the RDD aggregate method to
+ * drastically reduce the communication overhead.
+ */
+
+ // common calculations for multiple nested methods
+ val numNodes = scala.math.pow(2, level).toInt
+ logDebug("numNodes = " + numNodes)
+ // Find the number of features by looking at the first sample.
+ val numFeatures = input.first().features.size
+ logDebug("numFeatures = " + numFeatures)
+ val numBins = bins(0).length
+ logDebug("numBins = " + numBins)
+
+ /** Find the filters used before reaching the current code. */
+ def findParentFilters(nodeIndex: Int): List[Filter] = {
+ if (level == 0) {
+ List[Filter]()
+ } else {
+ val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex
+ filters(nodeFilterIndex)
+ }
+ }
+
+ /**
+ * Find whether the sample is valid input for the current node, i.e., whether it passes through
+ * all the filters for the current node.
+ */
+ def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = {
+ // leaf
+ if ((level > 0) & (parentFilters.length == 0)) {
+ return false
+ }
+
+ // Apply each filter and check sample validity. Return false when invalid condition found.
+ for (filter <- parentFilters) {
+ val features = labeledPoint.features
+ val featureIndex = filter.split.feature
+ val threshold = filter.split.threshold
+ val comparison = filter.comparison
+ val categories = filter.split.categories
+ val isFeatureContinuous = filter.split.featureType == Continuous
+ val feature = features(featureIndex)
+ if (isFeatureContinuous) {
+ comparison match {
+ case -1 => if (feature > threshold) return false
+ case 1 => if (feature <= threshold) return false
+ }
+ } else {
+ val containsFeature = categories.contains(feature)
+ comparison match {
+ case -1 => if (!containsFeature) return false
+ case 1 => if (containsFeature) return false
+ }
+
+ }
+ }
+
+ // Return true when the sample is valid for all filters.
+ true
+ }
+
+ /**
+ * Find bin for one feature.
+ */
+ def findBin(
+ featureIndex: Int,
+ labeledPoint: LabeledPoint,
+ isFeatureContinuous: Boolean): Int = {
+ val binForFeatures = bins(featureIndex)
+ val feature = labeledPoint.features(featureIndex)
+
+ /**
+ * Binary search helper method for continuous feature.
+ */
+ def binarySearchForBins(): Int = {
+ var left = 0
+ var right = binForFeatures.length - 1
+ while (left <= right) {
+ val mid = left + (right - left) / 2
+ val bin = binForFeatures(mid)
+ val lowThreshold = bin.lowSplit.threshold
+ val highThreshold = bin.highSplit.threshold
+ if ((lowThreshold < feature) & (highThreshold >= feature)){
+ return mid
+ }
+ else if (lowThreshold >= feature) {
+ right = mid - 1
+ }
+ else {
+ left = mid + 1
+ }
+ }
+ -1
+ }
+
+ /**
+ * Sequential search helper method to find bin for categorical feature.
+ */
+ def sequentialBinSearchForCategoricalFeature(): Int = {
+ val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex)
+ var binIndex = 0
+ while (binIndex < numCategoricalBins) {
+ val bin = bins(featureIndex)(binIndex)
+ val category = bin.category
+ val features = labeledPoint.features
+ if (category == features(featureIndex)) {
+ return binIndex
+ }
+ binIndex += 1
+ }
+ -1
+ }
+
+ if (isFeatureContinuous) {
+ // Perform binary search for finding bin for continuous features.
+ val binIndex = binarySearchForBins()
+ if (binIndex == -1){
+ throw new UnknownError("no bin was found for continuous variable.")
+ }
+ binIndex
+ } else {
+ // Perform sequential search to find bin for categorical features.
+ val binIndex = sequentialBinSearchForCategoricalFeature()
+ if (binIndex == -1){
+ throw new UnknownError("no bin was found for categorical variable.")
+ }
+ binIndex
+ }
+ }
+
+ /**
+ * Finds bins for all nodes (and all features) at a given level.
+ * For l nodes, k features the storage is as follows:
+ * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk,
+ * where b_ij is an integer between 0 and numBins - 1.
+ * Invalid sample is denoted by noting bin for feature 1 as -1.
+ */
+ def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = {
+ // Calculate bin index and label per feature per node.
+ val arr = new Array[Double](1 + (numFeatures * numNodes))
+ arr(0) = labeledPoint.label
+ var nodeIndex = 0
+ while (nodeIndex < numNodes) {
+ val parentFilters = findParentFilters(nodeIndex)
+ // Find out whether the sample qualifies for the particular node.
+ val sampleValid = isSampleValid(parentFilters, labeledPoint)
+ val shift = 1 + numFeatures * nodeIndex
+ if (!sampleValid) {
+ // Mark one bin as -1 is sufficient.
+ arr(shift) = InvalidBinIndex
+ } else {
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+ arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinuous)
+ featureIndex += 1
+ }
+ }
+ nodeIndex += 1
+ }
+ arr
+ }
+
+ /**
+ * Performs a sequential aggregation over a partition for classification. For l nodes,
+ * k features, either the left count or the right count of one of the p bins is
+ * incremented based upon whether the feature is classified as 0 or 1.
+ *
+ * @param agg Array[Double] storing aggregate calculation of size
+ * 2 * numSplits * numFeatures*numNodes for classification
+ * @param arr Array[Double] of size 1 + (numFeatures * numNodes)
+ * @return Array[Double] storing aggregate calculation of size
+ * 2 * numSplits * numFeatures * numNodes for classification
+ */
+ def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) {
+ // Iterate over all nodes.
+ var nodeIndex = 0
+ while (nodeIndex < numNodes) {
+ // Check whether the instance was valid for this nodeIndex.
+ val validSignalIndex = 1 + numFeatures * nodeIndex
+ val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
+ if (isSampleValidForNode) {
+ // actual class label
+ val label = arr(0)
+ // Iterate over all features.
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ // Find the bin index for this feature.
+ val arrShift = 1 + numFeatures * nodeIndex
+ val arrIndex = arrShift + featureIndex
+ // Update the left or right count for one bin.
+ val aggShift = 2 * numBins * numFeatures * nodeIndex
+ val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2
+ label match {
+ case 0.0 => agg(aggIndex) = agg(aggIndex) + 1
+ case 1.0 => agg(aggIndex + 1) = agg(aggIndex + 1) + 1
+ }
+ featureIndex += 1
+ }
+ }
+ nodeIndex += 1
+ }
+ }
+
+ /**
+ * Performs a sequential aggregation over a partition for regression. For l nodes, k features,
+ * the count, sum, sum of squares of one of the p bins is incremented.
+ *
+ * @param agg Array[Double] storing aggregate calculation of size
+ * 3 * numSplits * numFeatures * numNodes for classification
+ * @param arr Array[Double] of size 1 + (numFeatures * numNodes)
+ * @return Array[Double] storing aggregate calculation of size
+ * 3 * numSplits * numFeatures * numNodes for regression
+ */
+ def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) {
+ // Iterate over all nodes.
+ var nodeIndex = 0
+ while (nodeIndex < numNodes) {
+ // Check whether the instance was valid for this nodeIndex.
+ val validSignalIndex = 1 + numFeatures * nodeIndex
+ val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
+ if (isSampleValidForNode) {
+ // actual class label
+ val label = arr(0)
+ // Iterate over all features.
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ // Find the bin index for this feature.
+ val arrShift = 1 + numFeatures * nodeIndex
+ val arrIndex = arrShift + featureIndex
+ // Update count, sum, and sum^2 for one bin.
+ val aggShift = 3 * numBins * numFeatures * nodeIndex
+ val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3
+ agg(aggIndex) = agg(aggIndex) + 1
+ agg(aggIndex + 1) = agg(aggIndex + 1) + label
+ agg(aggIndex + 2) = agg(aggIndex + 2) + label*label
+ featureIndex += 1
+ }
+ }
+ nodeIndex += 1
+ }
+ }
+
+ /**
+ * Performs a sequential aggregation over a partition.
+ */
+ def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = {
+ strategy.algo match {
+ case Classification => classificationBinSeqOp(arr, agg)
+ case Regression => regressionBinSeqOp(arr, agg)
+ }
+ agg
+ }
+
+ // Calculate bin aggregate length for classification or regression.
+ val binAggregateLength = strategy.algo match {
+ case Classification => 2 * numBins * numFeatures * numNodes
+ case Regression => 3 * numBins * numFeatures * numNodes
+ }
+ logDebug("binAggregateLength = " + binAggregateLength)
+
+ /**
+ * Combines the aggregates from partitions.
+ * @param agg1 Array containing aggregates from one or more partitions
+ * @param agg2 Array containing aggregates from one or more partitions
+ * @return Combined aggregate from agg1 and agg2
+ */
+ def binCombOp(agg1: Array[Double], agg2: Array[Double]): Array[Double] = {
+ var index = 0
+ val combinedAggregate = new Array[Double](binAggregateLength)
+ while (index < binAggregateLength) {
+ combinedAggregate(index) = agg1(index) + agg2(index)
+ index += 1
+ }
+ combinedAggregate
+ }
+
+ // Find feature bins for all nodes at a level.
+ val binMappedRDD = input.map(x => findBinsForLevel(x))
+
+ // Calculate bin aggregates.
+ val binAggregates = {
+ binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp)
+ }
+ logDebug("binAggregates.length = " + binAggregates.length)
+
+ /**
+ * Calculates the information gain for all splits based upon left/right split aggregates.
+ * @param leftNodeAgg left node aggregates
+ * @param featureIndex feature index
+ * @param splitIndex split index
+ * @param rightNodeAgg right node aggregate
+ * @param topImpurity impurity of the parent node
+ * @return information gain and statistics for all splits
+ */
+ def calculateGainForSplit(
+ leftNodeAgg: Array[Array[Double]],
+ featureIndex: Int,
+ splitIndex: Int,
+ rightNodeAgg: Array[Array[Double]],
+ topImpurity: Double): InformationGainStats = {
+ strategy.algo match {
+ case Classification =>
+ val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex)
+ val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1)
+ val leftCount = left0Count + left1Count
+
+ val right0Count = rightNodeAgg(featureIndex)(2 * splitIndex)
+ val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1)
+ val rightCount = right0Count + right1Count
+
+ val impurity = {
+ if (level > 0) {
+ topImpurity
+ } else {
+ // Calculate impurity for root node.
+ strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count)
+ }
+ }
+
+ if (leftCount == 0) {
+ return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,1)
+ }
+ if (rightCount == 0) {
+ return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue,0)
+ }
+
+ val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
+ val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)
+
+ val leftWeight = leftCount.toDouble / (leftCount + rightCount)
+ val rightWeight = rightCount.toDouble / (leftCount + rightCount)
+
+ val gain = {
+ if (level > 0) {
+ impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+ } else {
+ impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+ }
+ }
+
+ val predict = (left1Count + right1Count) / (leftCount + rightCount)
+
+ new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
+ case Regression =>
+ val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex)
+ val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1)
+ val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2)
+
+ val rightCount = rightNodeAgg(featureIndex)(3 * splitIndex)
+ val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1)
+ val rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2)
+
+ val impurity = {
+ if (level > 0) {
+ topImpurity
+ } else {
+ // Calculate impurity for root node.
+ val count = leftCount + rightCount
+ val sum = leftSum + rightSum
+ val sumSquares = leftSumSquares + rightSumSquares
+ strategy.impurity.calculate(count, sum, sumSquares)
+ }
+ }
+
+ if (leftCount == 0) {
+ return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,
+ rightSum / rightCount)
+ }
+ if (rightCount == 0) {
+ return new InformationGainStats(0, topImpurity ,topImpurity,
+ Double.MinValue, leftSum / leftCount)
+ }
+
+ val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares)
+ val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares)
+
+ val leftWeight = leftCount.toDouble / (leftCount + rightCount)
+ val rightWeight = rightCount.toDouble / (leftCount + rightCount)
+
+ val gain = {
+ if (level > 0) {
+ impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+ } else {
+ impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+ }
+ }
+
+ val predict = (leftSum + rightSum) / (leftCount + rightCount)
+ new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
+ }
+ }
+
+ /**
+ * Extracts left and right split aggregates.
+ * @param binData Array[Double] of size 2*numFeatures*numSplits
+ * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double],
+ * Array[Double]) where each array is of size(numFeature,2*(numSplits-1))
+ */
+ def extractLeftRightNodeAggregates(
+ binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = {
+ strategy.algo match {
+ case Classification =>
+ // Initialize left and right split aggregates.
+ val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1))
+ val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1))
+ // Iterate over all features.
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ // shift for this featureIndex
+ val shift = 2 * featureIndex * numBins
+
+ // left node aggregate for the lowest split
+ leftNodeAgg(featureIndex)(0) = binData(shift + 0)
+ leftNodeAgg(featureIndex)(1) = binData(shift + 1)
+
+ // right node aggregate for the highest split
+ rightNodeAgg(featureIndex)(2 * (numBins - 2))
+ = binData(shift + (2 * (numBins - 1)))
+ rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1)
+ = binData(shift + (2 * (numBins - 1)) + 1)
+
+ // Iterate over all splits.
+ var splitIndex = 1
+ while (splitIndex < numBins - 1) {
+ // calculating left node aggregate for a split as a sum of left node aggregate of a
+ // lower split and the left bin aggregate of a bin where the split is a high split
+ leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2 * splitIndex) +
+ leftNodeAgg(featureIndex)(2 * splitIndex - 2)
+ leftNodeAgg(featureIndex)(2 * splitIndex + 1) = binData(shift + 2 * splitIndex + 1) +
+ leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1)
+
+ // calculating right node aggregate for a split as a sum of right node aggregate of a
+ // higher split and the right bin aggregate of a bin where the split is a low split
+ rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) =
+ binData(shift + (2 *(numBins - 2 - splitIndex))) +
+ rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
+ rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) =
+ binData(shift + (2* (numBins - 2 - splitIndex) + 1)) +
+ rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1)
+
+ splitIndex += 1
+ }
+ featureIndex += 1
+ }
+ (leftNodeAgg, rightNodeAgg)
+ case Regression =>
+ // Initialize left and right split aggregates.
+ val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1))
+ val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1))
+ // Iterate over all features.
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ // shift for this featureIndex
+ val shift = 3 * featureIndex * numBins
+ // left node aggregate for the lowest split
+ leftNodeAgg(featureIndex)(0) = binData(shift + 0)
+ leftNodeAgg(featureIndex)(1) = binData(shift + 1)
+ leftNodeAgg(featureIndex)(2) = binData(shift + 2)
+
+ // right node aggregate for the highest split
+ rightNodeAgg(featureIndex)(3 * (numBins - 2)) =
+ binData(shift + (3 * (numBins - 1)))
+ rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) =
+ binData(shift + (3 * (numBins - 1)) + 1)
+ rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) =
+ binData(shift + (3 * (numBins - 1)) + 2)
+
+ // Iterate over all splits.
+ var splitIndex = 1
+ while (splitIndex < numBins - 1) {
+ // calculating left node aggregate for a split as a sum of left node aggregate of a
+ // lower split and the left bin aggregate of a bin where the split is a high split
+ leftNodeAgg(featureIndex)(3 * splitIndex) = binData(shift + 3 * splitIndex) +
+ leftNodeAgg(featureIndex)(3 * splitIndex - 3)
+ leftNodeAgg(featureIndex)(3 * splitIndex + 1) = binData(shift + 3 * splitIndex + 1) +
+ leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1)
+ leftNodeAgg(featureIndex)(3 * splitIndex + 2) = binData(shift + 3 * splitIndex + 2) +
+ leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2)
+
+ // calculating right node aggregate for a split as a sum of right node aggregate of a
+ // higher split and the right bin aggregate of a bin where the split is a low split
+ rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) =
+ binData(shift + (3 * (numBins - 2 - splitIndex))) +
+ rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
+ rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) =
+ binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) +
+ rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1)
+ rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) =
+ binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) +
+ rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2)
+
+ splitIndex += 1
+ }
+ featureIndex += 1
+ }
+ (leftNodeAgg, rightNodeAgg)
+ }
+ }
+
+ /**
+ * Calculates information gain for all nodes splits.
+ */
+ def calculateGainsForAllNodeSplits(
+ leftNodeAgg: Array[Array[Double]],
+ rightNodeAgg: Array[Array[Double]],
+ nodeImpurity: Double): Array[Array[InformationGainStats]] = {
+ val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)
+
+ for (featureIndex <- 0 until numFeatures) {
+ for (splitIndex <- 0 until numBins - 1) {
+ gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex,
+ splitIndex, rightNodeAgg, nodeImpurity)
+ }
+ }
+ gains
+ }
+
+ /**
+ * Find the best split for a node.
+ * @param binData Array[Double] of size 2 * numSplits * numFeatures
+ * @param nodeImpurity impurity of the top node
+ * @return tuple of split and information gain
+ */
+ def binsToBestSplit(
+ binData: Array[Double],
+ nodeImpurity: Double): (Split, InformationGainStats) = {
+
+ logDebug("node impurity = " + nodeImpurity)
+
+ // Extract left right node aggregates.
+ val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
+
+ // Calculate gains for all splits.
+ val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)
+
+ val (bestFeatureIndex,bestSplitIndex, gainStats) = {
+ // Initialize with infeasible values.
+ var bestFeatureIndex = Int.MinValue
+ var bestSplitIndex = Int.MinValue
+ var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0)
+ // Iterate over features.
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ // Iterate over all splits.
+ var splitIndex = 0
+ while (splitIndex < numBins - 1) {
+ val gainStats = gains(featureIndex)(splitIndex)
+ if (gainStats.gain > bestGainStats.gain) {
+ bestGainStats = gainStats
+ bestFeatureIndex = featureIndex
+ bestSplitIndex = splitIndex
+ }
+ splitIndex += 1
+ }
+ featureIndex += 1
+ }
+ (bestFeatureIndex, bestSplitIndex, bestGainStats)
+ }
+
+ logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex))
+ logDebug("best split bin = " + splits(bestFeatureIndex)(bestSplitIndex))
+
+ (splits(bestFeatureIndex)(bestSplitIndex), gainStats)
+ }
+
+ /**
+ * Get bin data for one node.
+ */
+ def getBinDataForNode(node: Int): Array[Double] = {
+ strategy.algo match {
+ case Classification =>
+ val shift = 2 * node * numBins * numFeatures
+ val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures)
+ binsForNode
+ case Regression =>
+ val shift = 3 * node * numBins * numFeatures
+ val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures)
+ binsForNode
+ }
+ }
+
+ // Calculate best splits for all nodes at a given level
+ val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
+ // Iterating over all nodes at this level
+ var node = 0
+ while (node < numNodes) {
+ val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node
+ val binsForNode: Array[Double] = getBinDataForNode(node)
+ logDebug("nodeImpurityIndex = " + nodeImpurityIndex)
+ val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
+ logDebug("node impurity = " + parentNodeImpurity)
+ bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
+ node += 1
+ }
+
+ bestSplits
+ }
+
+ /**
+ * Returns split and bins for decision tree calculation.
+ * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
+ * for DecisionTree
+ * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
+ * parameters for construction the DecisionTree
+ * @return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree
+ * .model.Split] of size (numFeatures, numSplits-1) and bins is an Array of [org.apache
+ * .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1)
+ */
+ protected[tree] def findSplitsBins(
+ input: RDD[LabeledPoint],
+ strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = {
+ val count = input.count()
+
+ // Find the number of features by looking at the first sample
+ val numFeatures = input.take(1)(0).features.size
+
+ val maxBins = strategy.maxBins
+ val numBins = if (maxBins <= count) maxBins else count.toInt
+ logDebug("numBins = " + numBins)
+
+ /*
+ * TODO: Add a require statement ensuring #bins is always greater than the categories.
+ * It's a limitation of the current implementation but a reasonable trade-off since features
+ * with large number of categories get favored over continuous features.
+ */
+ if (strategy.categoricalFeaturesInfo.size > 0) {
+ val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
+ require(numBins >= maxCategoriesForFeatures)
+ }
+
+ // Calculate the number of sample for approximate quantile calculation.
+ val requiredSamples = numBins*numBins
+ val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
+ logDebug("fraction of data used for calculating quantiles = " + fraction)
+
+ // sampled input for RDD calculation
+ val sampledInput = input.sample(false, fraction, new XORShiftRandom().nextInt()).collect()
+ val numSamples = sampledInput.length
+
+ val stride: Double = numSamples.toDouble / numBins
+ logDebug("stride = " + stride)
+
+ strategy.quantileCalculationStrategy match {
+ case Sort =>
+ val splits = Array.ofDim[Split](numFeatures, numBins - 1)
+ val bins = Array.ofDim[Bin](numFeatures, numBins)
+
+ // Find all splits.
+
+ // Iterate over all features.
+ var featureIndex = 0
+ while (featureIndex < numFeatures){
+ // Check whether the feature is continuous.
+ val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+ if (isFeatureContinuous) {
+ val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
+ val stride: Double = numSamples.toDouble / numBins
+ logDebug("stride = " + stride)
+ for (index <- 0 until numBins - 1) {
+ val sampleIndex = (index + 1) * stride.toInt
+ val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List())
+ splits(featureIndex)(index) = split
+ }
+ } else {
+ val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
+ require(maxFeatureValue < numBins, "number of categories should be less than number " +
+ "of bins")
+
+ // For categorical variables, each bin is a category. The bins are sorted and they
+ // are ordered by calculating the centroid of their corresponding labels.
+ val centroidForCategories =
+ sampledInput.map(lp => (lp.features(featureIndex),lp.label))
+ .groupBy(_._1)
+ .mapValues(x => x.map(_._2).sum / x.map(_._1).length)
+
+ // Check for missing categorical variables and putting them last in the sorted list.
+ val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]()
+ for (i <- 0 until maxFeatureValue) {
+ if (centroidForCategories.contains(i)) {
+ fullCentroidForCategories(i) = centroidForCategories(i)
+ } else {
+ fullCentroidForCategories(i) = Double.MaxValue
+ }
+ }
+
+ // bins sorted by centroids
+ val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2)
+
+ logDebug("centriod for categorical variable = " + categoriesSortedByCentroid)
+
+ var categoriesForSplit = List[Double]()
+ categoriesSortedByCentroid.iterator.zipWithIndex.foreach {
+ case ((key, value), index) =>
+ categoriesForSplit = key :: categoriesForSplit
+ splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical,
+ categoriesForSplit)
+ bins(featureIndex)(index) = {
+ if (index == 0) {
+ new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
+ splits(featureIndex)(0), Categorical, key)
+ } else {
+ new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
+ Categorical, key)
+ }
+ }
+ }
+ }
+ featureIndex += 1
+ }
+
+ // Find all bins.
+ featureIndex = 0
+ while (featureIndex < numFeatures) {
+ val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+ if (isFeatureContinuous) { // Bins for categorical variables are already assigned.
+ bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
+ splits(featureIndex)(0), Continuous, Double.MinValue)
+ for (index <- 1 until numBins - 1){
+ val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
+ Continuous, Double.MinValue)
+ bins(featureIndex)(index) = bin
+ }
+ bins(featureIndex)(numBins-1) = new Bin(splits(featureIndex)(numBins-2),
+ new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
+ }
+ featureIndex += 1
+ }
+ (splits,bins)
+ case MinMax =>
+ throw new UnsupportedOperationException("minmax not supported yet.")
+ case ApproxHist =>
+ throw new UnsupportedOperationException("approximate histogram not supported yet.")
+ }
+ }
+
+ val usage = """
+ Usage: DecisionTreeRunner [slices] --algo --trainDataDir path --testDataDir path --maxDepth num [--impurity ] [--maxBins num]
+ """
+
+ def main(args: Array[String]) {
+
+ if (args.length < 2) {
+ System.err.println(usage)
+ System.exit(1)
+ }
+
+ val sc = new SparkContext(args(0), "DecisionTree")
+
+ val argList = args.toList.drop(1)
+ type OptionMap = Map[Symbol, Any]
+
+ def nextOption(map : OptionMap, list: List[String]): OptionMap = {
+ list match {
+ case Nil => map
+ case "--algo" :: string :: tail => nextOption(map ++ Map('algo -> string), tail)
+ case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail)
+ case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail)
+ case "--maxBins" :: string :: tail => nextOption(map ++ Map('maxBins -> string), tail)
+ case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string)
+ , tail)
+ case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string),
+ tail)
+ case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail)
+ case option :: tail => logError("Unknown option " + option)
+ sys.exit(1)
+ }
+ }
+ val options = nextOption(Map(), argList)
+ logDebug(options.toString())
+
+ // Load training data.
+ val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString)
+
+ // Identify the type of algorithm.
+ val algoStr = options.get('algo).get.toString
+ val algo = algoStr match {
+ case "Classification" => Classification
+ case "Regression" => Regression
+ }
+
+ // Identify the type of impurity.
+ val impurityStr = options.getOrElse('impurity,
+ if (algo == Classification) "Gini" else "Variance").toString
+ val impurity = impurityStr match {
+ case "Gini" => Gini
+ case "Entropy" => Entropy
+ case "Variance" => Variance
+ }
+
+ val maxDepth = options.getOrElse('maxDepth, "1").toString.toInt
+ val maxBins = options.getOrElse('maxBins, "100").toString.toInt
+
+ val strategy = new Strategy(algo, impurity, maxDepth, maxBins)
+ val model = DecisionTree.train(trainData, strategy)
+
+ // Load test data.
+ val testData = loadLabeledData(sc, options.get('testDataDir).get.toString)
+
+ // Measure algorithm accuracy
+ if (algo == Classification) {
+ val accuracy = accuracyScore(model, testData)
+ logDebug("accuracy = " + accuracy)
+ }
+
+ if (algo == Regression) {
+ val mse = meanSquaredError(model, testData)
+ logDebug("mean square error = " + mse)
+ }
+
+ sc.stop()
+ }
+
+ /**
+ * Load labeled data from a file. The data format used here is
+ * , ...,
+ * where , are feature values in Double and is the corresponding label as Double.
+ *
+ * @param sc SparkContext
+ * @param dir Directory to the input data files.
+ * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is
+ * the label, and the second element represents the feature values (an array of Double).
+ */
+ def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = {
+ sc.textFile(dir).map { line =>
+ val parts = line.trim().split(",")
+ val label = parts(0).toDouble
+ val features = Vectors.dense(parts.slice(1,parts.length).map(_.toDouble))
+ LabeledPoint(label, features)
+ }
+ }
+
+ // TODO: Port this method to a generic metrics package.
+ /**
+ * Calculates the classifier accuracy.
+ */
+ private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint],
+ threshold: Double = 0.5): Double = {
+ def predictedValue(features: Vector) = {
+ if (model.predict(features) < threshold) 0.0 else 1.0
+ }
+ val correctCount = data.filter(y => predictedValue(y.features) == y.label).count()
+ val count = data.count()
+ logDebug("correct prediction count = " + correctCount)
+ logDebug("data count = " + count)
+ correctCount.toDouble / count
+ }
+
+ // TODO: Port this method to a generic metrics package
+ /**
+ * Calculates the mean squared error for regression.
+ */
+ private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
+ data.map { y =>
+ val err = tree.predict(y.features) - y.label
+ err * err
+ }.mean()
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md b/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md
new file mode 100644
index 0000000000000..0fd71aa9735bc
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md
@@ -0,0 +1,17 @@
+This package contains the default implementation of the decision tree algorithm.
+
+The decision tree algorithm supports:
++ Binary classification
++ Regression
++ Information loss calculation with entropy and gini for classification and variance for regression
++ Both continuous and categorical features
+
+# Tree improvements
++ Node model pruning
++ Printing to dot files
+
+# Future Ensemble Extensions
+
++ Random forests
++ Boosting
++ Extremely randomized trees
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
new file mode 100644
index 0000000000000..2dd1f0f27b8f5
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+/**
+ * Enum to select the algorithm for the decision tree
+ */
+object Algo extends Enumeration {
+ type Algo = Value
+ val Classification, Regression = Value
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala
new file mode 100644
index 0000000000000..09ee0586c58fa
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+/**
+ * Enum to describe whether a feature is "continuous" or "categorical"
+ */
+object FeatureType extends Enumeration {
+ type FeatureType = Value
+ val Continuous, Categorical = Value
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala
new file mode 100644
index 0000000000000..2457a480c2a14
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+/**
+ * Enum for selecting the quantile calculation strategy
+ */
+object QuantileStrategy extends Enumeration {
+ type QuantileStrategy = Value
+ val Sort, MinMax, ApproxHist = Value
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
new file mode 100644
index 0000000000000..df565f3eb8859
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+import org.apache.spark.mllib.tree.impurity.Impurity
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+
+/**
+ * Stores all the configuration options for tree construction
+ * @param algo classification or regression
+ * @param impurity criterion used for information gain calculation
+ * @param maxDepth maximum depth of the tree
+ * @param maxBins maximum number of bins used for splitting features
+ * @param quantileCalculationStrategy algorithm for calculating quantiles
+ * @param categoricalFeaturesInfo A map storing information about the categorical variables and the
+ * number of discrete values they take. For example, an entry (n ->
+ * k) implies the feature n is categorical with k categories 0,
+ * 1, 2, ... , k-1. It's important to note that features are
+ * zero-indexed.
+ */
+class Strategy (
+ val algo: Algo,
+ val impurity: Impurity,
+ val maxDepth: Int,
+ val maxBins: Int = 100,
+ val quantileCalculationStrategy: QuantileStrategy = Sort,
+ val categoricalFeaturesInfo: Map[Int,Int] = Map[Int,Int]()) extends Serializable
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
new file mode 100644
index 0000000000000..b93995fcf9441
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impurity
+
+/**
+ * Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during
+ * binary classification.
+ */
+object Entropy extends Impurity {
+
+ def log2(x: Double) = scala.math.log(x) / scala.math.log(2)
+
+ /**
+ * entropy calculation
+ * @param c0 count of instances with label 0
+ * @param c1 count of instances with label 1
+ * @return entropy value
+ */
+ def calculate(c0: Double, c1: Double): Double = {
+ if (c0 == 0 || c1 == 0) {
+ 0
+ } else {
+ val total = c0 + c1
+ val f0 = c0 / total
+ val f1 = c1 / total
+ -(f0 * log2(f0)) - (f1 * log2(f1))
+ }
+ }
+
+ def calculate(count: Double, sum: Double, sumSquares: Double): Double =
+ throw new UnsupportedOperationException("Entropy.calculate")
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
new file mode 100644
index 0000000000000..c0407554a91b3
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impurity
+
+/**
+ * Class for calculating the
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]]
+ * during binary classification.
+ */
+object Gini extends Impurity {
+
+ /**
+ * Gini coefficient calculation
+ * @param c0 count of instances with label 0
+ * @param c1 count of instances with label 1
+ * @return Gini coefficient value
+ */
+ override def calculate(c0: Double, c1: Double): Double = {
+ if (c0 == 0 || c1 == 0) {
+ 0
+ } else {
+ val total = c0 + c1
+ val f0 = c0 / total
+ val f1 = c1 / total
+ 1 - f0 * f0 - f1 * f1
+ }
+ }
+
+ def calculate(count: Double, sum: Double, sumSquares: Double): Double =
+ throw new UnsupportedOperationException("Gini.calculate")
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
new file mode 100644
index 0000000000000..a4069063af2ad
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impurity
+
+/**
+ * Trait for calculating information gain.
+ */
+trait Impurity extends Serializable {
+
+ /**
+ * information calculation for binary classification
+ * @param c0 count of instances with label 0
+ * @param c1 count of instances with label 1
+ * @return information value
+ */
+ def calculate(c0 : Double, c1 : Double): Double
+
+ /**
+ * information calculation for regression
+ * @param count number of instances
+ * @param sum sum of labels
+ * @param sumSquares summation of squares of the labels
+ * @return information value
+ */
+ def calculate(count: Double, sum: Double, sumSquares: Double): Double
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
new file mode 100644
index 0000000000000..b74577dcec167
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impurity
+
+/**
+ * Class for calculating variance during regression
+ */
+object Variance extends Impurity {
+ override def calculate(c0: Double, c1: Double): Double =
+ throw new UnsupportedOperationException("Variance.calculate")
+
+ /**
+ * variance calculation
+ * @param count number of instances
+ * @param sum sum of labels
+ * @param sumSquares summation of squares of the labels
+ */
+ override def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
+ val squaredLoss = sumSquares - (sum * sum) / count
+ squaredLoss / count
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
new file mode 100644
index 0000000000000..a57faa13745f7
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import org.apache.spark.mllib.tree.configuration.FeatureType._
+
+/**
+ * Used for "binning" the features bins for faster best split calculation. For a continuous
+ * feature, a bin is determined by a low and a high "split". For a categorical feature,
+ * the a bin is determined using a single label value (category).
+ * @param lowSplit signifying the lower threshold for the continuous feature to be
+ * accepted in the bin
+ * @param highSplit signifying the upper threshold for the continuous feature to be
+ * accepted in the bin
+ * @param featureType type of feature -- categorical or continuous
+ * @param category categorical label value accepted in the bin
+ */
+case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
new file mode 100644
index 0000000000000..a6dca84a2ce09
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.linalg.Vector
+
+/**
+ * Model to store the decision tree parameters
+ * @param topNode root node
+ * @param algo algorithm type -- classification or regression
+ */
+class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable {
+
+ /**
+ * Predict values for a single data point using the model trained.
+ *
+ * @param features array representing a single data point
+ * @return Double prediction from the trained model
+ */
+ def predict(features: Vector): Double = {
+ topNode.predictIfLeaf(features)
+ }
+
+ /**
+ * Predict values for the given data set using the model trained.
+ *
+ * @param features RDD representing data points to be predicted
+ * @return RDD[Int] where each entry contains the corresponding prediction
+ */
+ def predict(features: RDD[Vector]): RDD[Double] = {
+ features.map(x => predict(x))
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala
new file mode 100644
index 0000000000000..ebc9595eafef3
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+/**
+ * Filter specifying a split and type of comparison to be applied on features
+ * @param split split specifying the feature index, type and threshold
+ * @param comparison integer specifying <,=,>
+ */
+case class Filter(split: Split, comparison: Int) {
+ // Comparison -1,0,1 signifies <.=,>
+ override def toString = " split = " + split + "comparison = " + comparison
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
new file mode 100644
index 0000000000000..99bf79cf12e45
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+/**
+ * Information gain statistics for each split
+ * @param gain information gain value
+ * @param impurity current node impurity
+ * @param leftImpurity left node impurity
+ * @param rightImpurity right node impurity
+ * @param predict predicted value
+ */
+class InformationGainStats(
+ val gain: Double,
+ val impurity: Double,
+ val leftImpurity: Double,
+ val rightImpurity: Double,
+ val predict: Double) extends Serializable {
+
+ override def toString = {
+ "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f"
+ .format(gain, impurity, leftImpurity, rightImpurity, predict)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
new file mode 100644
index 0000000000000..aac3f9ce308f7
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import org.apache.spark.Logging
+import org.apache.spark.mllib.tree.configuration.FeatureType._
+import org.apache.spark.mllib.linalg.Vector
+
+/**
+ * Node in a decision tree
+ * @param id integer node id
+ * @param predict predicted value at the node
+ * @param isLeaf whether the leaf is a node
+ * @param split split to calculate left and right nodes
+ * @param leftNode left child
+ * @param rightNode right child
+ * @param stats information gain stats
+ */
+class Node (
+ val id: Int,
+ val predict: Double,
+ val isLeaf: Boolean,
+ val split: Option[Split],
+ var leftNode: Option[Node],
+ var rightNode: Option[Node],
+ val stats: Option[InformationGainStats]) extends Serializable with Logging {
+
+ override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " +
+ "split = " + split + ", stats = " + stats
+
+ /**
+ * build the left node and right nodes if not leaf
+ * @param nodes array of nodes
+ */
+ def build(nodes: Array[Node]): Unit = {
+
+ logDebug("building node " + id + " at level " +
+ (scala.math.log(id + 1)/scala.math.log(2)).toInt )
+ logDebug("id = " + id + ", split = " + split)
+ logDebug("stats = " + stats)
+ logDebug("predict = " + predict)
+ if (!isLeaf) {
+ val leftNodeIndex = id * 2 + 1
+ val rightNodeIndex = id * 2 + 2
+ leftNode = Some(nodes(leftNodeIndex))
+ rightNode = Some(nodes(rightNodeIndex))
+ leftNode.get.build(nodes)
+ rightNode.get.build(nodes)
+ }
+ }
+
+ /**
+ * predict value if node is not leaf
+ * @param feature feature value
+ * @return predicted value
+ */
+ def predictIfLeaf(feature: Vector) : Double = {
+ if (isLeaf) {
+ predict
+ } else{
+ if (split.get.featureType == Continuous) {
+ if (feature(split.get.feature) <= split.get.threshold) {
+ leftNode.get.predictIfLeaf(feature)
+ } else {
+ rightNode.get.predictIfLeaf(feature)
+ }
+ } else {
+ if (split.get.categories.contains(feature(split.get.feature))) {
+ leftNode.get.predictIfLeaf(feature)
+ } else {
+ rightNode.get.predictIfLeaf(feature)
+ }
+ }
+ }
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
new file mode 100644
index 0000000000000..4e64a81dda74e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
+
+/**
+ * Split applied to a feature
+ * @param feature feature index
+ * @param threshold threshold for continuous feature
+ * @param featureType type of feature -- categorical or continuous
+ * @param categories accepted values for categorical variables
+ */
+case class Split(
+ feature: Int,
+ threshold: Double,
+ featureType: FeatureType,
+ categories: List[Double]){
+
+ override def toString =
+ "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType +
+ ", categories = " + categories
+}
+
+/**
+ * Split with minimum threshold for continuous features. Helps with the smallest bin creation.
+ * @param feature feature index
+ * @param featureType type of feature -- categorical or continuous
+ */
+class DummyLowSplit(feature: Int, featureType: FeatureType)
+ extends Split(feature, Double.MinValue, featureType, List())
+
+/**
+ * Split with maximum threshold for continuous features. Helps with the highest bin creation.
+ * @param feature feature index
+ * @param featureType type of feature -- categorical or continuous
+ */
+class DummyHighSplit(feature: Int, featureType: FeatureType)
+ extends Split(feature, Double.MaxValue, featureType, List())
+
+/**
+ * Split with no acceptable feature values for categorical features. Helps with the first bin
+ * creation.
+ * @param feature feature index
+ * @param featureType type of feature -- categorical or continuous
+ */
+class DummyCategoricalSplit(feature: Int, featureType: FeatureType)
+ extends Split(feature, Double.MaxValue, featureType, List())
+
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
index 2e03684e62861..81e4eda2a68c4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
@@ -24,6 +24,7 @@ import org.jblas.DoubleMatrix
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
/**
@@ -74,7 +75,7 @@ object LinearDataGenerator {
val y = x.map { xi =>
new DoubleMatrix(1, xi.length, xi: _*).dot(weightsMat) + intercept + eps * rnd.nextGaussian()
}
- y.zip(x).map(p => LabeledPoint(p._1, p._2))
+ y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2)))
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala
index 52c4a71d621a1..61498dcc2be00 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala
@@ -22,6 +22,7 @@ import scala.util.Random
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.linalg.Vectors
/**
* Generate test data for LogisticRegression. This class chooses positive labels
@@ -54,7 +55,7 @@ object LogisticRegressionDataGenerator {
val x = Array.fill[Double](nfeatures) {
rnd.nextGaussian() + (y * eps)
}
- LabeledPoint(y, x)
+ LabeledPoint(y, Vectors.dense(x))
}
data
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 08cd9ab05547b..cb85e433bfc73 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -17,15 +17,13 @@
package org.apache.spark.mllib.util
+import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV,
+ squaredDistance => breezeSquaredDistance}
+
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext._
-
-import org.jblas.DoubleMatrix
-
import org.apache.spark.mllib.regression.LabeledPoint
-
-import breeze.linalg.{Vector => BV, SparseVector => BSV, squaredDistance => breezeSquaredDistance}
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
/**
* Helper methods to load, save and pre-process data used in ML Lib.
@@ -40,6 +38,107 @@ object MLUtils {
eps
}
+ /**
+ * Multiclass label parser, which parses a string into double.
+ */
+ val multiclassLabelParser: String => Double = _.toDouble
+
+ /**
+ * Binary label parser, which outputs 1.0 (positive) if the value is greater than 0.5,
+ * or 0.0 (negative) otherwise.
+ */
+ val binaryLabelParser: String => Double = label => if (label.toDouble > 0.5) 1.0 else 0.0
+
+ /**
+ * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint].
+ * The LIBSVM format is a text-based format used by LIBSVM and LIBLINEAR.
+ * Each line represents a labeled sparse feature vector using the following format:
+ * {{{label index1:value1 index2:value2 ...}}}
+ * where the indices are one-based and in ascending order.
+ * This method parses each line into a [[org.apache.spark.mllib.regression.LabeledPoint]],
+ * where the feature indices are converted to zero-based.
+ *
+ * @param sc Spark context
+ * @param path file or directory path in any Hadoop-supported file system URI
+ * @param labelParser parser for labels, default: 1.0 if label > 0.5 or 0.0 otherwise
+ * @param numFeatures number of features, which will be determined from the input data if a
+ * negative value is given. The default value is -1.
+ * @param minSplits min number of partitions, default: sc.defaultMinSplits
+ * @return labeled data stored as an RDD[LabeledPoint]
+ */
+ def loadLibSVMData(
+ sc: SparkContext,
+ path: String,
+ labelParser: String => Double,
+ numFeatures: Int,
+ minSplits: Int): RDD[LabeledPoint] = {
+ val parsed = sc.textFile(path, minSplits)
+ .map(_.trim)
+ .filter(!_.isEmpty)
+ .map(_.split(' '))
+ // Determine number of features.
+ val d = if (numFeatures >= 0) {
+ numFeatures
+ } else {
+ parsed.map { items =>
+ if (items.length > 1) {
+ items.last.split(':')(0).toInt
+ } else {
+ 0
+ }
+ }.reduce(math.max)
+ }
+ parsed.map { items =>
+ val label = labelParser(items.head)
+ val (indices, values) = items.tail.map { item =>
+ val indexAndValue = item.split(':')
+ val index = indexAndValue(0).toInt - 1
+ val value = indexAndValue(1).toDouble
+ (index, value)
+ }.unzip
+ LabeledPoint(label, Vectors.sparse(d, indices.toArray, values.toArray))
+ }
+ }
+
+ // Convenient methods for calling from Java.
+
+ /**
+ * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint],
+ * with number of features determined automatically and the default number of partitions.
+ */
+ def loadLibSVMData(sc: SparkContext, path: String): RDD[LabeledPoint] =
+ loadLibSVMData(sc, path, binaryLabelParser, -1, sc.defaultMinSplits)
+
+ /**
+ * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint],
+ * with number of features specified explicitly and the default number of partitions.
+ */
+ def loadLibSVMData(sc: SparkContext, path: String, numFeatures: Int): RDD[LabeledPoint] =
+ loadLibSVMData(sc, path, binaryLabelParser, numFeatures, sc.defaultMinSplits)
+
+ /**
+ * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint],
+ * with the given label parser, number of features determined automatically,
+ * and the default number of partitions.
+ */
+ def loadLibSVMData(
+ sc: SparkContext,
+ path: String,
+ labelParser: String => Double): RDD[LabeledPoint] =
+ loadLibSVMData(sc, path, labelParser, -1, sc.defaultMinSplits)
+
+ /**
+ * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint],
+ * with the given label parser, number of features specified explicitly,
+ * and the default number of partitions.
+ */
+ def loadLibSVMData(
+ sc: SparkContext,
+ path: String,
+ labelParser: String => Double,
+ numFeatures: Int): RDD[LabeledPoint] =
+ loadLibSVMData(sc, path, labelParser, numFeatures, sc.defaultMinSplits)
+
/**
* Load labeled data from a file. The data format used here is
* , ...
@@ -54,7 +153,7 @@ object MLUtils {
sc.textFile(dir).map { line =>
val parts = line.split(',')
val label = parts(0).toDouble
- val features = parts(1).trim().split(' ').map(_.toDouble)
+ val features = Vectors.dense(parts(1).trim().split(' ').map(_.toDouble))
LabeledPoint(label, features)
}
}
@@ -68,7 +167,7 @@ object MLUtils {
* @param dir Directory to save the data.
*/
def saveLabeledData(data: RDD[LabeledPoint], dir: String) {
- val dataStr = data.map(x => x.label + "," + x.features.mkString(" "))
+ val dataStr = data.map(x => x.label + "," + x.features.toArray.mkString(" "))
dataStr.saveAsTextFile(dir)
}
@@ -76,44 +175,52 @@ object MLUtils {
* Utility function to compute mean and standard deviation on a given dataset.
*
* @param data - input data set whose statistics are computed
- * @param nfeatures - number of features
- * @param nexamples - number of examples in input dataset
+ * @param numFeatures - number of features
+ * @param numExamples - number of examples in input dataset
*
* @return (yMean, xColMean, xColSd) - Tuple consisting of
* yMean - mean of the labels
* xColMean - Row vector with mean for every column (or feature) of the input data
* xColSd - Row vector standard deviation for every column (or feature) of the input data.
*/
- def computeStats(data: RDD[LabeledPoint], nfeatures: Int, nexamples: Long):
- (Double, DoubleMatrix, DoubleMatrix) = {
- val yMean: Double = data.map { labeledPoint => labeledPoint.label }.reduce(_ + _) / nexamples
-
- // NOTE: We shuffle X by column here to compute column sum and sum of squares.
- val xColSumSq: RDD[(Int, (Double, Double))] = data.flatMap { labeledPoint =>
- val nCols = labeledPoint.features.length
- // Traverse over every column and emit (col, value, value^2)
- Iterator.tabulate(nCols) { i =>
- (i, (labeledPoint.features(i), labeledPoint.features(i)*labeledPoint.features(i)))
- }
- }.reduceByKey { case(x1, x2) =>
- (x1._1 + x2._1, x1._2 + x2._2)
+ def computeStats(
+ data: RDD[LabeledPoint],
+ numFeatures: Int,
+ numExamples: Long): (Double, Vector, Vector) = {
+ val brzData = data.map { case LabeledPoint(label, features) =>
+ (label, features.toBreeze)
}
- val xColSumsMap = xColSumSq.collectAsMap()
-
- val xColMean = DoubleMatrix.zeros(nfeatures, 1)
- val xColSd = DoubleMatrix.zeros(nfeatures, 1)
-
- // Compute mean and unbiased variance using column sums
- var col = 0
- while (col < nfeatures) {
- xColMean.put(col, xColSumsMap(col)._1 / nexamples)
- val variance =
- (xColSumsMap(col)._2 - (math.pow(xColSumsMap(col)._1, 2) / nexamples)) / nexamples
- xColSd.put(col, math.sqrt(variance))
- col += 1
+ val aggStats = brzData.aggregate(
+ (0L, 0.0, BDV.zeros[Double](numFeatures), BDV.zeros[Double](numFeatures))
+ )(
+ seqOp = (c, v) => (c, v) match {
+ case ((n, sumLabel, sum, sumSq), (label, features)) =>
+ features.activeIterator.foreach { case (i, x) =>
+ sumSq(i) += x * x
+ }
+ (n + 1L, sumLabel + label, sum += features, sumSq)
+ },
+ combOp = (c1, c2) => (c1, c2) match {
+ case ((n1, sumLabel1, sum1, sumSq1), (n2, sumLabel2, sum2, sumSq2)) =>
+ (n1 + n2, sumLabel1 + sumLabel2, sum1 += sum2, sumSq1 += sumSq2)
+ }
+ )
+ val (nl, sumLabel, sum, sumSq) = aggStats
+
+ require(nl > 0, "Input data is empty.")
+ require(nl == numExamples)
+
+ val n = nl.toDouble
+ val yMean = sumLabel / n
+ val mean = sum / n
+ val std = new Array[Double](sum.length)
+ var i = 0
+ while (i < numFeatures) {
+ std(i) = sumSq(i) / n - mean(i) * mean(i)
+ i += 1
}
- (yMean, xColMean, xColSd)
+ (yMean, Vectors.fromBreeze(mean), Vectors.dense(std))
}
/**
@@ -144,6 +251,18 @@ object MLUtils {
val sumSquaredNorm = norm1 * norm1 + norm2 * norm2
val normDiff = norm1 - norm2
var sqDist = 0.0
+ /*
+ * The relative error is
+ *
+ * EPSILON * ( \|a\|_2^2 + \|b\\_2^2 + 2 |a^T b|) / ( \|a - b\|_2^2 ),
+ *
+ * which is bounded by
+ *
+ * 2.0 * EPSILON * ( \|a\|_2^2 + \|b\|_2^2 ) / ( (\|a\|_2 - \|b\|_2)^2 ).
+ *
+ * The bound doesn't need the inner product, so we can use it as a sufficient condition to
+ * check quickly whether the inner product approach is accurate.
+ */
val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON)
if (precisionBound1 < precision) {
sqDist = sumSquaredNorm - 2.0 * v1.dot(v2)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
index c96c94f70eef7..e300c3dbe1fe0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala
@@ -23,6 +23,7 @@ import org.jblas.DoubleMatrix
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
/**
@@ -58,7 +59,7 @@ object SVMDataGenerator {
}
val yD = new DoubleMatrix(1, x.length, x: _*).dot(trueWeights) + rnd.nextGaussian() * 0.1
val y = if (yD < 0) 0.0 else 1.0
- LabeledPoint(y, x)
+ LabeledPoint(y, Vectors.dense(x))
}
MLUtils.saveLabeledData(data, outputPath)
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
index 073ded6f36933..c80b1134ed1b2 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
@@ -19,6 +19,7 @@
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.junit.After;
import org.junit.Assert;
@@ -45,12 +46,12 @@ public void tearDown() {
}
private static final List POINTS = Arrays.asList(
- new LabeledPoint(0, new double[] {1.0, 0.0, 0.0}),
- new LabeledPoint(0, new double[] {2.0, 0.0, 0.0}),
- new LabeledPoint(1, new double[] {0.0, 1.0, 0.0}),
- new LabeledPoint(1, new double[] {0.0, 2.0, 0.0}),
- new LabeledPoint(2, new double[] {0.0, 0.0, 1.0}),
- new LabeledPoint(2, new double[] {0.0, 0.0, 2.0})
+ new LabeledPoint(0, Vectors.dense(1.0, 0.0, 0.0)),
+ new LabeledPoint(0, Vectors.dense(2.0, 0.0, 0.0)),
+ new LabeledPoint(1, Vectors.dense(0.0, 1.0, 0.0)),
+ new LabeledPoint(1, Vectors.dense(0.0, 2.0, 0.0)),
+ new LabeledPoint(2, Vectors.dense(0.0, 0.0, 1.0)),
+ new LabeledPoint(2, Vectors.dense(0.0, 0.0, 2.0))
);
private int validatePrediction(List points, NaiveBayesModel model) {
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
index 117e5eaa8b78e..4701a5e545020 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
@@ -17,7 +17,6 @@
package org.apache.spark.mllib.classification;
-
import java.io.Serializable;
import java.util.List;
@@ -28,7 +27,6 @@
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-
import org.apache.spark.mllib.regression.LabeledPoint;
public class JavaSVMSuite implements Serializable {
@@ -94,5 +92,4 @@ public void runSVMUsingStaticMethods() {
int numAccurate = validatePrediction(validationData, model);
Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
}
-
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
index 2c4d795f96e4e..c6d8425ffc38d 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
@@ -19,10 +19,10 @@
import java.io.Serializable;
-import com.google.common.collect.Lists;
-
import scala.Tuple2;
+import com.google.common.collect.Lists;
+
import org.junit.Test;
import static org.junit.Assert.*;
@@ -36,7 +36,7 @@ public void denseArrayConstruction() {
@Test
public void sparseArrayConstruction() {
- Vector v = Vectors.sparse(3, Lists.newArrayList(
+ Vector v = Vectors.sparse(3, Lists.>newArrayList(
new Tuple2(0, 2.0),
new Tuple2(2, 3.0)));
assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0);
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java
index f44b25cd44d19..f725924a2d971 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java
@@ -59,7 +59,7 @@ int validatePrediction(List validationData, LassoModel model) {
@Test
public void runLassoUsingConstructor() {
int nPoints = 10000;
- double A = 2.0;
+ double A = 0.0;
double[] weights = {-1.5, 1.0e-2};
JavaRDD testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
@@ -80,7 +80,7 @@ public void runLassoUsingConstructor() {
@Test
public void runLassoUsingStaticMethods() {
int nPoints = 10000;
- double A = 2.0;
+ double A = 0.0;
double[] weights = {-1.5, 1.0e-2};
JavaRDD testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java
index 2fdd5fc8fdca6..03714ae7e4d00 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java
@@ -55,30 +55,27 @@ public void tearDown() {
return errorSum / validationData.size();
}
- List generateRidgeData(int numPoints, int nfeatures, double eps) {
+ List generateRidgeData(int numPoints, int numFeatures, double std) {
org.jblas.util.Random.seed(42);
// Pick weights as random values distributed uniformly in [-0.5, 0.5]
- DoubleMatrix w = DoubleMatrix.rand(nfeatures, 1).subi(0.5);
- // Set first two weights to eps
- w.put(0, 0, eps);
- w.put(1, 0, eps);
- return LinearDataGenerator.generateLinearInputAsList(0.0, w.data, numPoints, 42, eps);
+ DoubleMatrix w = DoubleMatrix.rand(numFeatures, 1).subi(0.5);
+ return LinearDataGenerator.generateLinearInputAsList(0.0, w.data, numPoints, 42, std);
}
@Test
public void runRidgeRegressionUsingConstructor() {
- int nexamples = 200;
- int nfeatures = 20;
- double eps = 10.0;
- List data = generateRidgeData(2*nexamples, nfeatures, eps);
+ int numExamples = 50;
+ int numFeatures = 20;
+ List data = generateRidgeData(2*numExamples, numFeatures, 10.0);
- JavaRDD testRDD = sc.parallelize(data.subList(0, nexamples));
- List validationData = data.subList(nexamples, 2*nexamples);
+ JavaRDD testRDD = sc.parallelize(data.subList(0, numExamples));
+ List validationData = data.subList(numExamples, 2 * numExamples);
RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD();
- ridgeSGDImpl.optimizer().setStepSize(1.0)
- .setRegParam(0.0)
- .setNumIterations(200);
+ ridgeSGDImpl.optimizer()
+ .setStepSize(1.0)
+ .setRegParam(0.0)
+ .setNumIterations(200);
RidgeRegressionModel model = ridgeSGDImpl.run(testRDD.rdd());
double unRegularizedErr = predictionError(validationData, model);
@@ -91,13 +88,12 @@ public void runRidgeRegressionUsingConstructor() {
@Test
public void runRidgeRegressionUsingStaticMethods() {
- int nexamples = 200;
- int nfeatures = 20;
- double eps = 10.0;
- List data = generateRidgeData(2*nexamples, nfeatures, eps);
+ int numExamples = 50;
+ int numFeatures = 20;
+ List data = generateRidgeData(2 * numExamples, numFeatures, 10.0);
- JavaRDD testRDD = sc.parallelize(data.subList(0, nexamples));
- List validationData = data.subList(nexamples, 2*nexamples);
+ JavaRDD testRDD = sc.parallelize(data.subList(0, numExamples));
+ List validationData = data.subList(numExamples, 2 * numExamples);
RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0);
double unRegularizedErr = predictionError(validationData, model);
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 05322b024d5f6..1e03c9df820b0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -20,11 +20,10 @@ package org.apache.spark.mllib.classification
import scala.util.Random
import scala.collection.JavaConversions._
-import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
-import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.LocalSparkContext
@@ -61,7 +60,7 @@ object LogisticRegressionSuite {
if (yVal > 0) 1 else 0
}
- val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Array(x1(i))))
+ val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Vectors.dense(Array(x1(i)))))
testData
}
@@ -113,7 +112,7 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Shoul
val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42)
val initialB = -1.0
- val initialWeights = Array(initialB)
+ val initialWeights = Vectors.dense(initialB)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index 9dd6c79ee6ad8..516895d04222d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -19,9 +19,9 @@ package org.apache.spark.mllib.classification
import scala.util.Random
-import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.LocalSparkContext
@@ -54,7 +54,7 @@ object NaiveBayesSuite {
if (rnd.nextDouble() < _theta(y)(j)) 1 else 0
}
- LabeledPoint(y, xi)
+ LabeledPoint(y, Vectors.dense(xi))
}
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index bc7abb568a172..dfacbfeee6fb4 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark.mllib.classification
import scala.util.Random
import scala.collection.JavaConversions._
-import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import org.jblas.DoubleMatrix
@@ -28,6 +27,7 @@ import org.jblas.DoubleMatrix
import org.apache.spark.SparkException
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.linalg.Vectors
object SVMSuite {
@@ -54,7 +54,7 @@ object SVMSuite {
intercept + 0.01 * rnd.nextGaussian()
if (yD < 0) 0.0 else 1.0
}
- y.zip(x).map(p => LabeledPoint(p._1, p._2))
+ y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2)))
}
}
@@ -110,7 +110,7 @@ class SVMSuite extends FunSuite with LocalSparkContext {
val initialB = -1.0
val initialC = -1.0
- val initialWeights = Array(initialB,initialC)
+ val initialWeights = Vectors.dense(initialB, initialC)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
@@ -150,10 +150,10 @@ class SVMSuite extends FunSuite with LocalSparkContext {
}
intercept[SparkException] {
- val model = SVMWithSGD.train(testRDDInvalid, 100)
+ SVMWithSGD.train(testRDDInvalid, 100)
}
// Turning off data validation should not throw an exception
- val noValidationModel = new SVMWithSGD().setValidateData(false).run(testRDDInvalid)
+ new SVMWithSGD().setValidateData(false).run(testRDDInvalid)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
index 631d0e2ad9cdb..c4b433499a091 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
@@ -20,13 +20,12 @@ package org.apache.spark.mllib.optimization
import scala.util.Random
import scala.collection.JavaConversions._
-import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
-import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.linalg.Vectors
object GradientDescentSuite {
@@ -58,8 +57,7 @@ object GradientDescentSuite {
if (yVal > 0) 1 else 0
}
- val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Array(x1(i))))
- testData
+ (0 until nPoints).map(i => LabeledPoint(y(i), Vectors.dense(x1(i))))
}
}
@@ -83,11 +81,11 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMa
// Add a extra variable consisting of all 1.0's for the intercept.
val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42)
val data = testData.map { case LabeledPoint(label, features) =>
- label -> Array(1.0, features: _*)
+ label -> Vectors.dense(1.0, features.toArray: _*)
}
val dataRDD = sc.parallelize(data, 2).cache()
- val initialWeightsWithIntercept = Array(1.0, initialWeights: _*)
+ val initialWeightsWithIntercept = Vectors.dense(1.0, initialWeights: _*)
val (_, loss) = GradientDescent.runMiniBatchSGD(
dataRDD,
@@ -113,13 +111,13 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMa
// Add a extra variable consisting of all 1.0's for the intercept.
val testData = GradientDescentSuite.generateGDInput(2.0, -1.5, 10000, 42)
val data = testData.map { case LabeledPoint(label, features) =>
- label -> Array(1.0, features: _*)
+ label -> Vectors.dense(1.0, features.toArray: _*)
}
val dataRDD = sc.parallelize(data, 2).cache()
// Prepare non-zero weights
- val initialWeightsWithIntercept = Array(1.0, 0.5)
+ val initialWeightsWithIntercept = Vectors.dense(1.0, 0.5)
val regParam0 = 0
val (newWeights0, loss0) = GradientDescent.runMiniBatchSGD(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
index 64e4cbb860f61..6aad9eb84e13c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
@@ -17,11 +17,9 @@
package org.apache.spark.mllib.regression
-
-import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
-import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
class LassoSuite extends FunSuite with LocalSparkContext {
@@ -36,29 +34,33 @@ class LassoSuite extends FunSuite with LocalSparkContext {
}
test("Lasso local random SGD") {
- val nPoints = 10000
+ val nPoints = 1000
val A = 2.0
val B = -1.5
val C = 1.0e-2
- val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 42)
-
- val testRDD = sc.parallelize(testData, 2)
- testRDD.cache()
+ val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B, C), nPoints, 42)
+ .map { case LabeledPoint(label, features) =>
+ LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))
+ }
+ val testRDD = sc.parallelize(testData, 2).cache()
val ls = new LassoWithSGD()
- ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(20)
+ ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40)
val model = ls.run(testRDD)
-
val weight0 = model.weights(0)
val weight1 = model.weights(1)
- assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
- assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
- assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]")
+ val weight2 = model.weights(2)
+ assert(weight0 >= 1.9 && weight0 <= 2.1, weight0 + " not in [1.9, 2.1]")
+ assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]")
+ assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]")
val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17)
+ .map { case LabeledPoint(label, features) =>
+ LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))
+ }
val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.
@@ -69,33 +71,39 @@ class LassoSuite extends FunSuite with LocalSparkContext {
}
test("Lasso local random SGD with initial weights") {
- val nPoints = 10000
+ val nPoints = 1000
val A = 2.0
val B = -1.5
val C = 1.0e-2
- val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 42)
+ val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B, C), nPoints, 42)
+ .map { case LabeledPoint(label, features) =>
+ LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))
+ }
+ val initialA = -1.0
val initialB = -1.0
val initialC = -1.0
- val initialWeights = Array(initialB,initialC)
+ val initialWeights = Vectors.dense(initialA, initialB, initialC)
- val testRDD = sc.parallelize(testData, 2)
- testRDD.cache()
+ val testRDD = sc.parallelize(testData, 2).cache()
val ls = new LassoWithSGD()
- ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(20)
+ ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40)
val model = ls.run(testRDD, initialWeights)
-
val weight0 = model.weights(0)
val weight1 = model.weights(1)
- assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
- assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
- assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]")
+ val weight2 = model.weights(2)
+ assert(weight0 >= 1.9 && weight0 <= 2.1, weight0 + " not in [1.9, 2.1]")
+ assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]")
+ assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]")
val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17)
+ .map { case LabeledPoint(label, features) =>
+ LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))
+ }
val validationRDD = sc.parallelize(validationData,2)
// Test prediction on RDD.
@@ -104,4 +112,10 @@ class LassoSuite extends FunSuite with LocalSparkContext {
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
+
+ test("do not support intercept") {
+ intercept[UnsupportedOperationException] {
+ new LassoWithSGD().setIntercept(true)
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
index 281f9df36ddb3..2f7d30708ce17 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
@@ -17,9 +17,9 @@
package org.apache.spark.mllib.regression
-import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
class LinearRegressionSuite extends FunSuite with LocalSparkContext {
@@ -41,11 +41,12 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
linReg.optimizer.setNumIterations(1000).setStepSize(1.0)
val model = linReg.run(testRDD)
-
assert(model.intercept >= 2.5 && model.intercept <= 3.5)
- assert(model.weights.length === 2)
- assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0)
- assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0)
+
+ val weights = model.weights
+ assert(weights.size === 2)
+ assert(weights(0) >= 9.0 && weights(0) <= 11.0)
+ assert(weights(1) >= 9.0 && weights(1) <= 11.0)
val validationData = LinearDataGenerator.generateLinearInput(
3.0, Array(10.0, 10.0), 100, 17)
@@ -57,4 +58,67 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
+
+ // Test if we can correctly learn Y = 10*X1 + 10*X2
+ test("linear regression without intercept") {
+ val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput(
+ 0.0, Array(10.0, 10.0), 100, 42), 2).cache()
+ val linReg = new LinearRegressionWithSGD().setIntercept(false)
+ linReg.optimizer.setNumIterations(1000).setStepSize(1.0)
+
+ val model = linReg.run(testRDD)
+
+ assert(model.intercept === 0.0)
+
+ val weights = model.weights
+ assert(weights.size === 2)
+ assert(weights(0) >= 9.0 && weights(0) <= 11.0)
+ assert(weights(1) >= 9.0 && weights(1) <= 11.0)
+
+ val validationData = LinearDataGenerator.generateLinearInput(
+ 0.0, Array(10.0, 10.0), 100, 17)
+ val validationRDD = sc.parallelize(validationData, 2).cache()
+
+ // Test prediction on RDD.
+ validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
+
+ // Test prediction on Array.
+ validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
+ }
+
+ // Test if we can correctly learn Y = 10*X1 + 10*X10000
+ test("sparse linear regression without intercept") {
+ val denseRDD = sc.parallelize(
+ LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42), 2)
+ val sparseRDD = denseRDD.map { case LabeledPoint(label, v) =>
+ val sv = Vectors.sparse(10000, Seq((0, v(0)), (9999, v(1))))
+ LabeledPoint(label, sv)
+ }.cache()
+ val linReg = new LinearRegressionWithSGD().setIntercept(false)
+ linReg.optimizer.setNumIterations(1000).setStepSize(1.0)
+
+ val model = linReg.run(sparseRDD)
+
+ assert(model.intercept === 0.0)
+
+ val weights = model.weights
+ assert(weights.size === 10000)
+ assert(weights(0) >= 9.0 && weights(0) <= 11.0)
+ assert(weights(9999) >= 9.0 && weights(9999) <= 11.0)
+
+ val validationData = LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 17)
+ val sparseValidationData = validationData.map { case LabeledPoint(label, v) =>
+ val sv = Vectors.sparse(10000, Seq((0, v(0)), (9999, v(1))))
+ LabeledPoint(label, sv)
+ }
+ val sparseValidationRDD = sc.parallelize(sparseValidationData, 2)
+
+ // Test prediction on RDD.
+ validatePrediction(
+ model.predict(sparseValidationRDD.map(_.features)).collect(), sparseValidationData)
+
+ // Test prediction on Array.
+ validatePrediction(
+ sparseValidationData.map(row => model.predict(row.features)), sparseValidationData)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
index 67dd06cc0f5eb..f66fc6ea6c1ec 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -17,14 +17,12 @@
package org.apache.spark.mllib.regression
+import org.scalatest.FunSuite
import org.jblas.DoubleMatrix
-import org.scalatest.BeforeAndAfterAll
-import org.scalatest.FunSuite
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
-
class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = {
@@ -33,22 +31,22 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
}.reduceLeft(_ + _) / predictions.size
}
- test("regularization with skewed weights") {
- val nexamples = 200
- val nfeatures = 20
- val eps = 10
+ test("ridge regression can help avoid overfitting") {
+
+ // For small number of examples and large variance of error distribution,
+ // ridge regression should give smaller generalization error that linear regression.
+
+ val numExamples = 50
+ val numFeatures = 20
org.jblas.util.Random.seed(42)
// Pick weights as random values distributed uniformly in [-0.5, 0.5]
- val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5)
- // Set first two weights to eps
- w.put(0, 0, eps)
- w.put(1, 0, eps)
+ val w = DoubleMatrix.rand(numFeatures, 1).subi(0.5)
// Use half of data for training and other half for validation
- val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, 2*nexamples, 42, eps)
- val testData = data.take(nexamples)
- val validationData = data.takeRight(nexamples)
+ val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, 2 * numExamples, 42, 10.0)
+ val testData = data.take(numExamples)
+ val validationData = data.takeRight(numExamples)
val testRDD = sc.parallelize(testData, 2).cache()
val validationRDD = sc.parallelize(validationData, 2).cache()
@@ -70,8 +68,14 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
val ridgeErr = predictionError(
ridgeModel.predict(validationRDD.map(_.features)).collect(), validationData)
- // Ridge CV-error should be lower than linear regression
+ // Ridge validation error should be lower than linear regression.
assert(ridgeErr < linearErr,
"ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")")
}
+
+ test("do not support intercept") {
+ intercept[UnsupportedOperationException] {
+ new RidgeRegressionWithSGD().setIntercept(true)
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
new file mode 100644
index 0000000000000..350130c914f26
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -0,0 +1,426 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
+import org.apache.spark.mllib.tree.model.Filter
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.FeatureType._
+import org.apache.spark.mllib.linalg.Vectors
+
+class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
+
+ @transient private var sc: SparkContext = _
+
+ override def beforeAll() {
+ sc = new SparkContext("local", "test")
+ }
+
+ override def afterAll() {
+ sc.stop()
+ System.clearProperty("spark.driver.port")
+ }
+
+ test("split and bin calculation") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Gini, 3, 100)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(bins.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+ }
+
+ test("split and bin calculation for categorical variables") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(
+ Classification,
+ Gini,
+ maxDepth = 3,
+ maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(bins.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+
+ // Check splits.
+
+ assert(splits(0)(0).feature === 0)
+ assert(splits(0)(0).threshold === Double.MinValue)
+ assert(splits(0)(0).featureType === Categorical)
+ assert(splits(0)(0).categories.length === 1)
+ assert(splits(0)(0).categories.contains(1.0))
+
+ assert(splits(0)(1).feature === 0)
+ assert(splits(0)(1).threshold === Double.MinValue)
+ assert(splits(0)(1).featureType === Categorical)
+ assert(splits(0)(1).categories.length === 2)
+ assert(splits(0)(1).categories.contains(1.0))
+ assert(splits(0)(1).categories.contains(0.0))
+
+ assert(splits(0)(2) === null)
+
+ assert(splits(1)(0).feature === 1)
+ assert(splits(1)(0).threshold === Double.MinValue)
+ assert(splits(1)(0).featureType === Categorical)
+ assert(splits(1)(0).categories.length === 1)
+ assert(splits(1)(0).categories.contains(0.0))
+
+ assert(splits(1)(1).feature === 1)
+ assert(splits(1)(1).threshold === Double.MinValue)
+ assert(splits(1)(1).featureType === Categorical)
+ assert(splits(1)(1).categories.length === 2)
+ assert(splits(1)(1).categories.contains(1.0))
+ assert(splits(1)(1).categories.contains(0.0))
+
+ assert(splits(1)(2) === null)
+
+ // Check bins.
+
+ assert(bins(0)(0).category === 1.0)
+ assert(bins(0)(0).lowSplit.categories.length === 0)
+ assert(bins(0)(0).highSplit.categories.length === 1)
+ assert(bins(0)(0).highSplit.categories.contains(1.0))
+
+ assert(bins(0)(1).category === 0.0)
+ assert(bins(0)(1).lowSplit.categories.length === 1)
+ assert(bins(0)(1).lowSplit.categories.contains(1.0))
+ assert(bins(0)(1).highSplit.categories.length === 2)
+ assert(bins(0)(1).highSplit.categories.contains(1.0))
+ assert(bins(0)(1).highSplit.categories.contains(0.0))
+
+ assert(bins(0)(2) === null)
+
+ assert(bins(1)(0).category === 0.0)
+ assert(bins(1)(0).lowSplit.categories.length === 0)
+ assert(bins(1)(0).highSplit.categories.length === 1)
+ assert(bins(1)(0).highSplit.categories.contains(0.0))
+
+ assert(bins(1)(1).category === 1.0)
+ assert(bins(1)(1).lowSplit.categories.length === 1)
+ assert(bins(1)(1).lowSplit.categories.contains(0.0))
+ assert(bins(1)(1).highSplit.categories.length === 2)
+ assert(bins(1)(1).highSplit.categories.contains(0.0))
+ assert(bins(1)(1).highSplit.categories.contains(1.0))
+
+ assert(bins(1)(2) === null)
+ }
+
+ test("split and bin calculations for categorical variables with no sample for one category") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(
+ Classification,
+ Gini,
+ maxDepth = 3,
+ maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+
+ // Check splits.
+
+ assert(splits(0)(0).feature === 0)
+ assert(splits(0)(0).threshold === Double.MinValue)
+ assert(splits(0)(0).featureType === Categorical)
+ assert(splits(0)(0).categories.length === 1)
+ assert(splits(0)(0).categories.contains(1.0))
+
+ assert(splits(0)(1).feature === 0)
+ assert(splits(0)(1).threshold === Double.MinValue)
+ assert(splits(0)(1).featureType === Categorical)
+ assert(splits(0)(1).categories.length === 2)
+ assert(splits(0)(1).categories.contains(1.0))
+ assert(splits(0)(1).categories.contains(0.0))
+
+ assert(splits(0)(2).feature === 0)
+ assert(splits(0)(2).threshold === Double.MinValue)
+ assert(splits(0)(2).featureType === Categorical)
+ assert(splits(0)(2).categories.length === 3)
+ assert(splits(0)(2).categories.contains(1.0))
+ assert(splits(0)(2).categories.contains(0.0))
+ assert(splits(0)(2).categories.contains(2.0))
+
+ assert(splits(0)(3) === null)
+
+ assert(splits(1)(0).feature === 1)
+ assert(splits(1)(0).threshold === Double.MinValue)
+ assert(splits(1)(0).featureType === Categorical)
+ assert(splits(1)(0).categories.length === 1)
+ assert(splits(1)(0).categories.contains(0.0))
+
+ assert(splits(1)(1).feature === 1)
+ assert(splits(1)(1).threshold === Double.MinValue)
+ assert(splits(1)(1).featureType === Categorical)
+ assert(splits(1)(1).categories.length === 2)
+ assert(splits(1)(1).categories.contains(1.0))
+ assert(splits(1)(1).categories.contains(0.0))
+
+ assert(splits(1)(2).feature === 1)
+ assert(splits(1)(2).threshold === Double.MinValue)
+ assert(splits(1)(2).featureType === Categorical)
+ assert(splits(1)(2).categories.length === 3)
+ assert(splits(1)(2).categories.contains(1.0))
+ assert(splits(1)(2).categories.contains(0.0))
+ assert(splits(1)(2).categories.contains(2.0))
+
+ assert(splits(1)(3) === null)
+
+ // Check bins.
+
+ assert(bins(0)(0).category === 1.0)
+ assert(bins(0)(0).lowSplit.categories.length === 0)
+ assert(bins(0)(0).highSplit.categories.length === 1)
+ assert(bins(0)(0).highSplit.categories.contains(1.0))
+
+ assert(bins(0)(1).category === 0.0)
+ assert(bins(0)(1).lowSplit.categories.length === 1)
+ assert(bins(0)(1).lowSplit.categories.contains(1.0))
+ assert(bins(0)(1).highSplit.categories.length === 2)
+ assert(bins(0)(1).highSplit.categories.contains(1.0))
+ assert(bins(0)(1).highSplit.categories.contains(0.0))
+
+ assert(bins(0)(2).category === 2.0)
+ assert(bins(0)(2).lowSplit.categories.length === 2)
+ assert(bins(0)(2).lowSplit.categories.contains(1.0))
+ assert(bins(0)(2).lowSplit.categories.contains(0.0))
+ assert(bins(0)(2).highSplit.categories.length === 3)
+ assert(bins(0)(2).highSplit.categories.contains(1.0))
+ assert(bins(0)(2).highSplit.categories.contains(0.0))
+ assert(bins(0)(2).highSplit.categories.contains(2.0))
+
+ assert(bins(0)(3) === null)
+
+ assert(bins(1)(0).category === 0.0)
+ assert(bins(1)(0).lowSplit.categories.length === 0)
+ assert(bins(1)(0).highSplit.categories.length === 1)
+ assert(bins(1)(0).highSplit.categories.contains(0.0))
+
+ assert(bins(1)(1).category === 1.0)
+ assert(bins(1)(1).lowSplit.categories.length === 1)
+ assert(bins(1)(1).lowSplit.categories.contains(0.0))
+ assert(bins(1)(1).highSplit.categories.length === 2)
+ assert(bins(1)(1).highSplit.categories.contains(0.0))
+ assert(bins(1)(1).highSplit.categories.contains(1.0))
+
+ assert(bins(1)(2).category === 2.0)
+ assert(bins(1)(2).lowSplit.categories.length === 2)
+ assert(bins(1)(2).lowSplit.categories.contains(0.0))
+ assert(bins(1)(2).lowSplit.categories.contains(1.0))
+ assert(bins(1)(2).highSplit.categories.length === 3)
+ assert(bins(1)(2).highSplit.categories.contains(0.0))
+ assert(bins(1)(2).highSplit.categories.contains(1.0))
+ assert(bins(1)(2).highSplit.categories.contains(2.0))
+
+ assert(bins(1)(3) === null)
+ }
+
+ test("classification stump with all categorical variables") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(
+ Classification,
+ Gini,
+ maxDepth = 3,
+ maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
+ Array[List[Filter]](), splits, bins)
+
+ val split = bestSplits(0)._1
+ assert(split.categories.length === 1)
+ assert(split.categories.contains(1.0))
+ assert(split.featureType === Categorical)
+ assert(split.threshold === Double.MinValue)
+
+ val stats = bestSplits(0)._2
+ assert(stats.gain > 0)
+ assert(stats.predict > 0.4)
+ assert(stats.predict < 0.5)
+ assert(stats.impurity > 0.2)
+ }
+
+ test("regression stump with all categorical variables") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(
+ Regression,
+ Variance,
+ maxDepth = 3,
+ maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
+ val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
+ Array[List[Filter]](), splits, bins)
+
+ val split = bestSplits(0)._1
+ assert(split.categories.length === 1)
+ assert(split.categories.contains(1.0))
+ assert(split.featureType === Categorical)
+ assert(split.threshold === Double.MinValue)
+
+ val stats = bestSplits(0)._2
+ assert(stats.gain > 0)
+ assert(stats.predict > 0.4)
+ assert(stats.predict < 0.5)
+ assert(stats.impurity > 0.2)
+ }
+
+ test("stump with fixed label 0 for Gini") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Gini, 3, 100)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins.length === 2)
+ assert(bins(0).length === 100)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+
+ val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
+ Array[List[Filter]](), splits, bins)
+ assert(bestSplits.length === 1)
+ assert(bestSplits(0)._1.feature === 0)
+ assert(bestSplits(0)._1.threshold === 10)
+ assert(bestSplits(0)._2.gain === 0)
+ assert(bestSplits(0)._2.leftImpurity === 0)
+ assert(bestSplits(0)._2.rightImpurity === 0)
+ }
+
+ test("stump with fixed label 1 for Gini") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Gini, 3, 100)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins.length === 2)
+ assert(bins(0).length === 100)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+
+ val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
+ Array[List[Filter]](), splits, bins)
+ assert(bestSplits.length === 1)
+ assert(bestSplits(0)._1.feature === 0)
+ assert(bestSplits(0)._1.threshold === 10)
+ assert(bestSplits(0)._2.gain === 0)
+ assert(bestSplits(0)._2.leftImpurity === 0)
+ assert(bestSplits(0)._2.rightImpurity === 0)
+ assert(bestSplits(0)._2.predict === 1)
+ }
+
+ test("stump with fixed label 0 for Entropy") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Entropy, 3, 100)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins.length === 2)
+ assert(bins(0).length === 100)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+
+ val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
+ Array[List[Filter]](), splits, bins)
+ assert(bestSplits.length === 1)
+ assert(bestSplits(0)._1.feature === 0)
+ assert(bestSplits(0)._1.threshold === 10)
+ assert(bestSplits(0)._2.gain === 0)
+ assert(bestSplits(0)._2.leftImpurity === 0)
+ assert(bestSplits(0)._2.rightImpurity === 0)
+ assert(bestSplits(0)._2.predict === 0)
+ }
+
+ test("stump with fixed label 1 for Entropy") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Entropy, 3, 100)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+ assert(splits.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins.length === 2)
+ assert(bins(0).length === 100)
+ assert(splits(0).length === 99)
+ assert(bins(0).length === 100)
+
+ val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
+ Array[List[Filter]](), splits, bins)
+ assert(bestSplits.length === 1)
+ assert(bestSplits(0)._1.feature === 0)
+ assert(bestSplits(0)._1.threshold === 10)
+ assert(bestSplits(0)._2.gain === 0)
+ assert(bestSplits(0)._2.leftImpurity === 0)
+ assert(bestSplits(0)._2.rightImpurity === 0)
+ assert(bestSplits(0)._2.predict === 1)
+ }
+}
+
+object DecisionTreeSuite {
+
+ def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = {
+ val arr = new Array[LabeledPoint](1000)
+ for (i <- 0 until 1000){
+ val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
+ arr(i) = lp
+ }
+ arr
+ }
+
+ def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = {
+ val arr = new Array[LabeledPoint](1000)
+ for (i <- 0 until 1000){
+ val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i))
+ arr(i) = lp
+ }
+ arr
+ }
+
+ def generateCategoricalDataPoints(): Array[LabeledPoint] = {
+ val arr = new Array[LabeledPoint](1000)
+ for (i <- 0 until 1000){
+ if (i < 600){
+ arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0))
+ } else {
+ arr(i) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0))
+ }
+ }
+ arr
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 60f053b381305..27d41c7869aa0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -17,14 +17,20 @@
package org.apache.spark.mllib.util
+import java.io.File
+
import org.scalatest.FunSuite
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => breezeNorm,
squaredDistance => breezeSquaredDistance}
+import com.google.common.base.Charsets
+import com.google.common.io.Files
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils._
-class MLUtilsSuite extends FunSuite {
+class MLUtilsSuite extends FunSuite with LocalSparkContext {
test("epsilon computation") {
assert(1.0 + EPSILON > 1.0, s"EPSILON is too small: $EPSILON.")
@@ -49,4 +55,55 @@ class MLUtilsSuite extends FunSuite {
assert((fastSquaredDist2 - squaredDist) <= precision * squaredDist, s"failed with m = $m")
}
}
+
+ test("compute stats") {
+ val data = Seq.fill(3)(Seq(
+ LabeledPoint(1.0, Vectors.dense(1.0, 2.0, 3.0)),
+ LabeledPoint(0.0, Vectors.dense(3.0, 4.0, 5.0))
+ )).flatten
+ val rdd = sc.parallelize(data, 2)
+ val (meanLabel, mean, std) = MLUtils.computeStats(rdd, 3, 6)
+ assert(meanLabel === 0.5)
+ assert(mean === Vectors.dense(2.0, 3.0, 4.0))
+ assert(std === Vectors.dense(1.0, 1.0, 1.0))
+ }
+
+ test("loadLibSVMData") {
+ val lines =
+ """
+ |+1 1:1.0 3:2.0 5:3.0
+ |-1
+ |-1 2:4.0 4:5.0 6:6.0
+ """.stripMargin
+ val tempDir = Files.createTempDir()
+ val file = new File(tempDir.getPath, "part-00000")
+ Files.write(lines, file, Charsets.US_ASCII)
+ val path = tempDir.toURI.toString
+
+ val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, 6).collect()
+ val pointsWithoutNumFeatures = MLUtils.loadLibSVMData(sc, path).collect()
+
+ for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) {
+ assert(points.length === 3)
+ assert(points(0).label === 1.0)
+ assert(points(0).features === Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
+ assert(points(1).label == 0.0)
+ assert(points(1).features == Vectors.sparse(6, Seq()))
+ assert(points(2).label === 0.0)
+ assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
+ }
+
+ val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MLUtils.multiclassLabelParser).collect()
+ assert(multiclassPoints.length === 3)
+ assert(multiclassPoints(0).label === 1.0)
+ assert(multiclassPoints(1).label === -1.0)
+ assert(multiclassPoints(2).label === -1.0)
+
+ try {
+ file.delete()
+ tempDir.delete()
+ } catch {
+ case t: Throwable =>
+ }
+ }
}
diff --git a/pom.xml b/pom.xml
index deb89b18ada73..c03bb35c99442 100644
--- a/pom.xml
+++ b/pom.xml
@@ -21,7 +21,7 @@
org.apache
apache
- 13
+ 14
org.apache.spark
spark-parent
@@ -54,11 +54,11 @@
JIRA
- https://spark-project.atlassian.net/browse/SPARK
+ https://issues.apache.org/jira/browse/SPARK
- 3.0.0
+ 3.0.4
@@ -110,7 +110,7 @@
1.6
- 2.10.3
+ 2.10.4
2.10
0.13.0
org.spark-project.akka
@@ -123,6 +123,10 @@
0.94.6
0.12.0
1.3.2
+ 1.2.3
+ 8.1.14.v20131031
+ 0.3.1
+ 3.0.0
64m
512m
@@ -192,22 +196,22 @@
org.eclipse.jetty
jetty-util
- 7.6.8.v20121106
+ ${jetty.version}
org.eclipse.jetty
jetty-security
- 7.6.8.v20121106
+ ${jetty.version}
org.eclipse.jetty
jetty-plus
- 7.6.8.v20121106
+ ${jetty.version}
org.eclipse.jetty
jetty-server
- 7.6.8.v20121106
+ ${jetty.version}
com.google.guava
@@ -273,7 +277,7 @@
com.twitter
chill_${scala.binary.version}
- 0.3.1
+ ${chill.version}
org.ow2.asm
@@ -288,7 +292,7 @@
com.twitter
chill-java
- 0.3.1
+ ${chill.version}
org.ow2.asm
@@ -373,14 +377,13 @@
org.apache.derby
derby
10.4.2.0
- test
net.liftweb
lift-json_${scala.binary.version}
2.5.1
@@ -392,33 +395,38 @@
com.codahale.metrics
metrics-core
- 3.0.0
+ ${codahale.metrics.version}
com.codahale.metrics
metrics-jvm
- 3.0.0
+ ${codahale.metrics.version}
com.codahale.metrics
metrics-json
- 3.0.0
+ ${codahale.metrics.version}
com.codahale.metrics
metrics-ganglia
- 3.0.0
+ ${codahale.metrics.version}
com.codahale.metrics
metrics-graphite
- 3.0.0
+ ${codahale.metrics.version}
org.scala-lang
scala-compiler
${scala.version}
+
+ org.scala-lang
+ scala-reflect
+ ${scala.version}
+
org.scala-lang
jline
@@ -571,6 +579,12 @@
+
+
+ org.codehaus.jackson
+ jackson-mapper-asl
+ 1.8.8
+
@@ -580,7 +594,7 @@
org.apache.maven.plugins
maven-enforcer-plugin
- 1.1.1
+ 1.3.1
enforce-versions
@@ -590,7 +604,7 @@
- 3.0.0
+ 3.0.4
${java.version}
@@ -603,12 +617,12 @@
org.codehaus.mojo
build-helper-maven-plugin
- 1.7
+ 1.8
net.alchim31.maven
scala-maven-plugin
- 3.1.5
+ 3.1.6
scala-compile-first
@@ -641,7 +655,6 @@
-deprecation
- -Xms64m
-Xms1024m
-Xmx1024m
-XX:PermSize=${PermGen}
@@ -670,7 +683,7 @@
org.apache.maven.plugins
maven-surefire-plugin
- 2.12.4
+ 2.17
true
@@ -684,7 +697,7 @@
${project.build.directory}/surefire-reports
.
${project.build.directory}/SparkTestSuite.txt
- -Xms64m -Xmx3g
+ -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m
@@ -709,7 +722,7 @@
org.apache.maven.plugins
maven-shade-plugin
- 2.0
+ 2.2
org.apache.maven.plugins
@@ -806,7 +819,6 @@
org.apache.maven.plugins
maven-jar-plugin
- 2.4
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 60f14ba37e35c..6b8740d9f21a1 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -17,7 +17,7 @@
import sbt._
import sbt.Classpaths.publishTask
-import Keys._
+import sbt.Keys._
import sbtassembly.Plugin._
import AssemblyKeys._
import scala.util.Properties
@@ -27,10 +27,10 @@ import com.typesafe.tools.mima.plugin.MimaKeys.previousArtifact
import scala.collection.JavaConversions._
// For Sonatype publishing
-//import com.jsuereth.pgp.sbtplugin.PgpKeys._
+// import com.jsuereth.pgp.sbtplugin.PgpKeys._
object SparkBuild extends Build {
- val SPARK_VERSION = "1.0.0-SNAPSHOT"
+ val SPARK_VERSION = "1.0.0-SNAPSHOT"
// Hadoop version to build against. For example, "1.0.4" for Apache releases, or
// "2.0.0-mr1-cdh4.2.0" for Cloudera Hadoop. Note that these variables can be set
@@ -43,6 +43,8 @@ object SparkBuild extends Build {
val DEFAULT_YARN = false
+ val DEFAULT_HIVE = false
+
// HBase version; set as appropriate.
val HBASE_VERSION = "0.94.6"
@@ -67,15 +69,17 @@ object SparkBuild extends Build {
lazy val sql = Project("sql", file("sql/core"), settings = sqlCoreSettings) dependsOn(core, catalyst)
- // Since hive is its own assembly, it depends on all of the modules.
- lazy val hive = Project("hive", file("sql/hive"), settings = hiveSettings) dependsOn(sql, graphx, bagel, mllib, streaming, repl)
+ lazy val hive = Project("hive", file("sql/hive"), settings = hiveSettings) dependsOn(sql)
+
+ lazy val maybeHive: Seq[ClasspathDependency] = if (isHiveEnabled) Seq(hive) else Seq()
+ lazy val maybeHiveRef: Seq[ProjectReference] = if (isHiveEnabled) Seq(hive) else Seq()
lazy val streaming = Project("streaming", file("streaming"), settings = streamingSettings) dependsOn(core)
lazy val mllib = Project("mllib", file("mllib"), settings = mllibSettings) dependsOn(core)
lazy val assemblyProj = Project("assembly", file("assembly"), settings = assemblyProjSettings)
- .dependsOn(core, graphx, bagel, mllib, streaming, repl, sql) dependsOn(maybeYarn: _*) dependsOn(maybeGanglia: _*)
+ .dependsOn(core, graphx, bagel, mllib, streaming, repl, sql) dependsOn(maybeYarn: _*) dependsOn(maybeHive: _*) dependsOn(maybeGanglia: _*)
lazy val assembleDeps = TaskKey[Unit]("assemble-deps", "Build assembly of dependencies and packages Spark projects")
@@ -101,6 +105,11 @@ object SparkBuild extends Build {
lazy val hadoopClient = if (hadoopVersion.startsWith("0.20.") || hadoopVersion == "1.0.0") "hadoop-core" else "hadoop-client"
val maybeAvro = if (hadoopVersion.startsWith("0.23.") && isYarnEnabled) Seq("org.apache.avro" % "avro" % "1.7.4") else Seq()
+ lazy val isHiveEnabled = Properties.envOrNone("SPARK_HIVE") match {
+ case None => DEFAULT_HIVE
+ case Some(v) => v.toBoolean
+ }
+
// Include Ganglia integration if the user has enabled Ganglia
// This is isolated from the normal build due to LGPL-licensed code in the library
lazy val isGangliaEnabled = Properties.envOrNone("SPARK_GANGLIA_LGPL").isDefined
@@ -141,18 +150,18 @@ object SparkBuild extends Build {
lazy val allExternalRefs = Seq[ProjectReference](externalTwitter, externalKafka, externalFlume, externalZeromq, externalMqtt)
lazy val examples = Project("examples", file("examples"), settings = examplesSettings)
- .dependsOn(core, mllib, graphx, bagel, streaming, externalTwitter, hive) dependsOn(allExternal: _*)
+ .dependsOn(core, mllib, graphx, bagel, streaming, hive) dependsOn(allExternal: _*)
// Everything except assembly, hive, tools, java8Tests and examples belong to packageProjects
- lazy val packageProjects = Seq[ProjectReference](core, repl, bagel, streaming, mllib, graphx, catalyst, sql) ++ maybeYarnRef ++ maybeGangliaRef
+ lazy val packageProjects = Seq[ProjectReference](core, repl, bagel, streaming, mllib, graphx, catalyst, sql) ++ maybeYarnRef ++ maybeHiveRef ++ maybeGangliaRef
lazy val allProjects = packageProjects ++ allExternalRefs ++
- Seq[ProjectReference](examples, tools, assemblyProj, hive) ++ maybeJava8Tests
+ Seq[ProjectReference](examples, tools, assemblyProj) ++ maybeJava8Tests
def sharedSettings = Defaults.defaultSettings ++ MimaBuild.mimaSettings(file(sparkHome)) ++ Seq(
organization := "org.apache.spark",
version := SPARK_VERSION,
- scalaVersion := "2.10.3",
+ scalaVersion := "2.10.4",
scalacOptions := Seq("-Xmax-classfile-name", "120", "-unchecked", "-deprecation",
"-target:" + SCALAC_JVM_VERSION),
javacOptions := Seq("-target", JAVAC_JVM_VERSION, "-source", JAVAC_JVM_VERSION),
@@ -169,6 +178,7 @@ object SparkBuild extends Build {
fork := true,
javaOptions in Test += "-Dspark.home=" + sparkHome,
javaOptions in Test += "-Dspark.testing=1",
+ javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true",
javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark").map { case (k,v) => s"-D$k=$v" }.toSeq,
javaOptions += "-Xmx3g",
// Show full stack trace and duration in test cases.
@@ -185,22 +195,21 @@ object SparkBuild extends Build {
concurrentRestrictions in Global += Tags.limit(Tags.Test, 1),
resolvers ++= Seq(
- // HTTPS is unavailable for Maven Central
"Maven Repository" at "http://repo.maven.apache.org/maven2",
"Apache Repository" at "https://repository.apache.org/content/repositories/releases",
"JBoss Repository" at "https://repository.jboss.org/nexus/content/repositories/releases/",
"MQTT Repository" at "https://repo.eclipse.org/content/repositories/paho-releases/",
- "Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/",
+ "Cloudera Repository" at "http://repository.cloudera.com/artifactory/cloudera-repos/",
// For Sonatype publishing
- //"sonatype-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots",
- //"sonatype-staging" at "https://oss.sonatype.org/service/local/staging/deploy/maven2/",
+ // "sonatype-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots",
+ // "sonatype-staging" at "https://oss.sonatype.org/service/local/staging/deploy/maven2/",
// also check the local Maven repository ~/.m2
Resolver.mavenLocal
),
publishMavenStyle := true,
- //useGpg in Global := true,
+ // useGpg in Global := true,
pomExtra := (
@@ -248,13 +257,13 @@ object SparkBuild extends Build {
*/
libraryDependencies ++= Seq(
- "io.netty" % "netty-all" % "4.0.17.Final",
- "org.eclipse.jetty" % "jetty-server" % "7.6.8.v20121106",
- "org.eclipse.jetty" % "jetty-util" % "7.6.8.v20121106",
- "org.eclipse.jetty" % "jetty-plus" % "7.6.8.v20121106",
- "org.eclipse.jetty" % "jetty-security" % "7.6.8.v20121106",
+ "io.netty" % "netty-all" % "4.0.17.Final",
+ "org.eclipse.jetty" % "jetty-server" % jettyVersion,
+ "org.eclipse.jetty" % "jetty-util" % jettyVersion,
+ "org.eclipse.jetty" % "jetty-plus" % jettyVersion,
+ "org.eclipse.jetty" % "jetty-security" % jettyVersion,
/** Workaround for SPARK-959. Dependency used by org.eclipse.jetty. Fixed in ivy 2.3.0. */
- "org.eclipse.jetty.orbit" % "javax.servlet" % "2.5.0.v201103041518" artifacts Artifact("javax.servlet", "jar", "jar"),
+ "org.eclipse.jetty.orbit" % "javax.servlet" % "3.0.0.v201112011016" artifacts Artifact("javax.servlet", "jar", "jar"),
"org.scalatest" %% "scalatest" % "1.9.1" % "test",
"org.scalacheck" %% "scalacheck" % "1.10.0" % "test",
"com.novocode" % "junit-interface" % "0.10" % "test",
@@ -277,16 +286,28 @@ object SparkBuild extends Build {
publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn
) ++ net.virtualvoid.sbt.graph.Plugin.graphSettings ++ ScalaStyleSettings
+ val akkaVersion = "2.2.3-shaded-protobuf"
+ val chillVersion = "0.3.1"
+ val codahaleMetricsVersion = "3.0.0"
+ val jblasVersion = "1.2.3"
+ val jettyVersion = "8.1.14.v20131031"
+ val hiveVersion = "0.12.0"
+ val parquetVersion = "1.3.2"
val slf4jVersion = "1.7.5"
val excludeNetty = ExclusionRule(organization = "org.jboss.netty")
+ val excludeEclipseJetty = ExclusionRule(organization = "org.eclipse.jetty")
val excludeAsm = ExclusionRule(organization = "org.ow2.asm")
val excludeOldAsm = ExclusionRule(organization = "asm")
val excludeCommonsLogging = ExclusionRule(organization = "commons-logging")
val excludeSLF4J = ExclusionRule(organization = "org.slf4j")
val excludeScalap = ExclusionRule(organization = "org.scala-lang", artifact = "scalap")
+ val excludeHadoop = ExclusionRule(organization = "org.apache.hadoop")
+ val excludeCurator = ExclusionRule(organization = "org.apache.curator")
+ val excludePowermock = ExclusionRule(organization = "org.powermock")
+
- def sparkPreviousArtifact(id: String, organization: String = "org.apache.spark",
+ def sparkPreviousArtifact(id: String, organization: String = "org.apache.spark",
version: String = "0.9.0-incubating", crossVersion: String = "2.10"): Option[sbt.ModuleID] = {
val fullId = if (crossVersion.isEmpty) id else id + "_" + crossVersion
Some(organization % fullId % version) // the artifact to compare binary compatibility with
@@ -305,9 +326,9 @@ object SparkBuild extends Build {
"commons-daemon" % "commons-daemon" % "1.0.10", // workaround for bug HADOOP-9407
"com.ning" % "compress-lzf" % "1.0.0",
"org.xerial.snappy" % "snappy-java" % "1.0.5",
- "org.spark-project.akka" %% "akka-remote" % "2.2.3-shaded-protobuf" excludeAll(excludeNetty),
- "org.spark-project.akka" %% "akka-slf4j" % "2.2.3-shaded-protobuf" excludeAll(excludeNetty),
- "org.spark-project.akka" %% "akka-testkit" % "2.2.3-shaded-protobuf" % "test",
+ "org.spark-project.akka" %% "akka-remote" % akkaVersion excludeAll(excludeNetty),
+ "org.spark-project.akka" %% "akka-slf4j" % akkaVersion excludeAll(excludeNetty),
+ "org.spark-project.akka" %% "akka-testkit" % akkaVersion % "test",
"org.json4s" %% "json4s-jackson" % "3.2.6" excludeAll(excludeScalap),
"it.unimi.dsi" % "fastutil" % "6.4.4",
"colt" % "colt" % "1.2.0",
@@ -317,12 +338,13 @@ object SparkBuild extends Build {
"org.apache.derby" % "derby" % "10.4.2.0" % "test",
"org.apache.hadoop" % hadoopClient % hadoopVersion excludeAll(excludeNetty, excludeAsm, excludeCommonsLogging, excludeSLF4J, excludeOldAsm),
"org.apache.curator" % "curator-recipes" % "2.4.0" excludeAll(excludeNetty),
- "com.codahale.metrics" % "metrics-core" % "3.0.0",
- "com.codahale.metrics" % "metrics-jvm" % "3.0.0",
- "com.codahale.metrics" % "metrics-json" % "3.0.0",
- "com.codahale.metrics" % "metrics-graphite" % "3.0.0",
- "com.twitter" %% "chill" % "0.3.1" excludeAll(excludeAsm),
- "com.twitter" % "chill-java" % "0.3.1" excludeAll(excludeAsm),
+ "com.codahale.metrics" % "metrics-core" % codahaleMetricsVersion,
+ "com.codahale.metrics" % "metrics-jvm" % codahaleMetricsVersion,
+ "com.codahale.metrics" % "metrics-json" % codahaleMetricsVersion,
+ "com.codahale.metrics" % "metrics-graphite" % codahaleMetricsVersion,
+ "com.twitter" %% "chill" % chillVersion excludeAll(excludeAsm),
+ "com.twitter" % "chill-java" % chillVersion excludeAll(excludeAsm),
+ "org.tachyonproject" % "tachyon" % "0.4.1-thrift" excludeAll(excludeHadoop, excludeCurator, excludeEclipseJetty, excludePowermock),
"com.clearspring.analytics" % "stream" % "2.5.1"
),
libraryDependencies ++= maybeAvro
@@ -356,14 +378,16 @@ object SparkBuild extends Build {
) ++ assemblySettings ++ extraAssemblySettings
def toolsSettings = sharedSettings ++ Seq(
- name := "spark-tools"
+ name := "spark-tools",
+ libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "scala-compiler" % v ),
+ libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "scala-reflect" % v )
) ++ assemblySettings ++ extraAssemblySettings
def graphxSettings = sharedSettings ++ Seq(
name := "spark-graphx",
previousArtifact := sparkPreviousArtifact("spark-graphx"),
libraryDependencies ++= Seq(
- "org.jblas" % "jblas" % "1.2.3"
+ "org.jblas" % "jblas" % jblasVersion
)
)
@@ -376,7 +400,7 @@ object SparkBuild extends Build {
name := "spark-mllib",
previousArtifact := sparkPreviousArtifact("spark-mllib"),
libraryDependencies ++= Seq(
- "org.jblas" % "jblas" % "1.2.3",
+ "org.jblas" % "jblas" % jblasVersion,
"org.scalanlp" %% "breeze" % "0.7"
)
)
@@ -396,22 +420,20 @@ object SparkBuild extends Build {
def sqlCoreSettings = sharedSettings ++ Seq(
name := "spark-sql",
libraryDependencies ++= Seq(
- "com.twitter" % "parquet-column" % "1.3.2",
- "com.twitter" % "parquet-hadoop" % "1.3.2"
+ "com.twitter" % "parquet-column" % parquetVersion,
+ "com.twitter" % "parquet-hadoop" % parquetVersion
)
)
// Since we don't include hive in the main assembly this project also acts as an alternative
// assembly jar.
- def hiveSettings = sharedSettings ++ assemblyProjSettings ++ Seq(
+ def hiveSettings = sharedSettings ++ Seq(
name := "spark-hive",
- jarName in assembly <<= version map { v => "spark-hive-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" },
- jarName in packageDependency <<= version map { v => "spark-hive-assembly-" + v + "-hadoop" + hadoopVersion + "-deps.jar" },
javaOptions += "-XX:MaxPermSize=1g",
libraryDependencies ++= Seq(
- "org.apache.hive" % "hive-metastore" % "0.12.0",
- "org.apache.hive" % "hive-exec" % "0.12.0",
- "org.apache.hive" % "hive-serde" % "0.12.0"
+ "org.apache.hive" % "hive-metastore" % hiveVersion,
+ "org.apache.hive" % "hive-exec" % hiveVersion,
+ "org.apache.hive" % "hive-serde" % hiveVersion
),
// Multiple queries rely on the TestHive singleton. See comments there for more details.
parallelExecution in Test := false,
@@ -542,7 +564,7 @@ object SparkBuild extends Build {
name := "spark-streaming-zeromq",
previousArtifact := sparkPreviousArtifact("spark-streaming-zeromq"),
libraryDependencies ++= Seq(
- "org.spark-project.akka" %% "akka-zeromq" % "2.2.3-shaded-protobuf" excludeAll(excludeNetty)
+ "org.spark-project.akka" %% "akka-zeromq" % akkaVersion excludeAll(excludeNetty)
)
)
diff --git a/project/plugins.sbt b/project/plugins.sbt
index 4ff6f67af45c0..d787237ddc540 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -1,4 +1,4 @@
-scalaVersion := "2.10.3"
+scalaVersion := "2.10.4"
resolvers += Resolver.url("artifactory", url("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases"))(Resolver.ivyStylePatterns)
@@ -22,3 +22,4 @@ addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.4.0")
addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6")
addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.0")
+
diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala
new file mode 100644
index 0000000000000..0142256e90fb7
--- /dev/null
+++ b/project/project/SparkPluginBuild.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import sbt._
+import sbt.Keys._
+
+/**
+ * This plugin project is there to define new scala style rules for spark. This is
+ * a plugin project so that this gets compiled first and is put on the classpath and
+ * becomes available for scalastyle sbt plugin.
+ */
+object SparkPluginDef extends Build {
+ lazy val root = Project("plugins", file(".")) dependsOn(sparkStyle)
+ lazy val sparkStyle = Project("spark-style", file("spark-style"), settings = styleSettings)
+ val sparkVersion = "1.0.0-SNAPSHOT"
+ // There is actually no need to publish this artifact.
+ def styleSettings = Defaults.defaultSettings ++ Seq (
+ name := "spark-style",
+ organization := "org.apache.spark",
+ version := sparkVersion,
+ scalaVersion := "2.10.4",
+ scalacOptions := Seq("-unchecked", "-deprecation"),
+ libraryDependencies ++= Dependencies.scalaStyle
+ )
+
+ object Dependencies {
+ val scalaStyle = Seq("org.scalastyle" %% "scalastyle" % "0.4.0")
+ }
+}
diff --git a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/SparkSpaceAfterCommentStartChecker.scala b/project/spark-style/src/main/scala/org/apache/spark/scalastyle/SparkSpaceAfterCommentStartChecker.scala
new file mode 100644
index 0000000000000..80d3faa3fe749
--- /dev/null
+++ b/project/spark-style/src/main/scala/org/apache/spark/scalastyle/SparkSpaceAfterCommentStartChecker.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.scalastyle
+
+import java.util.regex.Pattern
+
+import org.scalastyle.{PositionError, ScalariformChecker, ScalastyleError}
+import scalariform.lexer.{MultiLineComment, ScalaDocComment, SingleLineComment, Token}
+import scalariform.parser.CompilationUnit
+
+class SparkSpaceAfterCommentStartChecker extends ScalariformChecker {
+ val errorKey: String = "insert.a.single.space.after.comment.start.and.before.end"
+
+ private def multiLineCommentRegex(comment: Token) =
+ Pattern.compile( """/\*\S+.*""", Pattern.DOTALL).matcher(comment.text.trim).matches() ||
+ Pattern.compile( """/\*.*\S\*/""", Pattern.DOTALL).matcher(comment.text.trim).matches()
+
+ private def scalaDocPatternRegex(comment: Token) =
+ Pattern.compile( """/\*\*\S+.*""", Pattern.DOTALL).matcher(comment.text.trim).matches() ||
+ Pattern.compile( """/\*\*.*\S\*/""", Pattern.DOTALL).matcher(comment.text.trim).matches()
+
+ private def singleLineCommentRegex(comment: Token): Boolean =
+ comment.text.trim.matches( """//\S+.*""") && !comment.text.trim.matches( """///+""")
+
+ override def verify(ast: CompilationUnit): List[ScalastyleError] = {
+ ast.tokens
+ .filter(hasComment)
+ .map {
+ _.associatedWhitespaceAndComments.comments.map {
+ case x: SingleLineComment if singleLineCommentRegex(x.token) => Some(x.token.offset)
+ case x: MultiLineComment if multiLineCommentRegex(x.token) => Some(x.token.offset)
+ case x: ScalaDocComment if scalaDocPatternRegex(x.token) => Some(x.token.offset)
+ case _ => None
+ }.flatten
+ }.flatten.map(PositionError(_))
+ }
+
+
+ private def hasComment(x: Token) =
+ x.associatedWhitespaceAndComments != null && !x.associatedWhitespaceAndComments.comments.isEmpty
+
+}
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index bf2454fd7e38e..d8667e84fedff 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -28,7 +28,8 @@
from pyspark.conf import SparkConf
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer
+from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
+ PairDeserializer
from pyspark.storagelevel import StorageLevel
from pyspark import rdd
from pyspark.rdd import RDD
@@ -257,6 +258,45 @@ def textFile(self, name, minSplits=None):
return RDD(self._jsc.textFile(name, minSplits), self,
UTF8Deserializer())
+ def wholeTextFiles(self, path):
+ """
+ Read a directory of text files from HDFS, a local file system
+ (available on all nodes), or any Hadoop-supported file system
+ URI. Each file is read as a single record and returned in a
+ key-value pair, where the key is the path of each file, the
+ value is the content of each file.
+
+ For example, if you have the following files::
+
+ hdfs://a-hdfs-path/part-00000
+ hdfs://a-hdfs-path/part-00001
+ ...
+ hdfs://a-hdfs-path/part-nnnnn
+
+ Do C{rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")},
+ then C{rdd} contains::
+
+ (a-hdfs-path/part-00000, its content)
+ (a-hdfs-path/part-00001, its content)
+ ...
+ (a-hdfs-path/part-nnnnn, its content)
+
+ NOTE: Small files are preferred, as each file will be loaded
+ fully in memory.
+
+ >>> dirPath = os.path.join(tempdir, "files")
+ >>> os.mkdir(dirPath)
+ >>> with open(os.path.join(dirPath, "1.txt"), "w") as file1:
+ ... file1.write("1")
+ >>> with open(os.path.join(dirPath, "2.txt"), "w") as file2:
+ ... file2.write("2")
+ >>> textFiles = sc.wholeTextFiles(dirPath)
+ >>> sorted(textFiles.collect())
+ [(u'.../1.txt', u'1'), (u'.../2.txt', u'2')]
+ """
+ return RDD(self._jsc.wholeTextFiles(path), self,
+ PairDeserializer(UTF8Deserializer(), UTF8Deserializer()))
+
def _checkpointFile(self, name, input_deserializer):
jrdd = self._jsc.checkpointFile(name)
return RDD(jrdd, self, input_deserializer)
@@ -383,8 +423,11 @@ def _getJavaStorageLevel(self, storageLevel):
raise Exception("storageLevel must be of type pyspark.StorageLevel")
newStorageLevel = self._jvm.org.apache.spark.storage.StorageLevel
- return newStorageLevel(storageLevel.useDisk, storageLevel.useMemory,
- storageLevel.deserialized, storageLevel.replication)
+ return newStorageLevel(storageLevel.useDisk,
+ storageLevel.useMemory,
+ storageLevel.useOffHeap,
+ storageLevel.deserialized,
+ storageLevel.replication)
def setJobGroup(self, groupId, description):
"""
@@ -425,7 +468,7 @@ def _test():
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
globs['tempdir'] = tempfile.mkdtemp()
atexit.register(lambda: shutil.rmtree(globs['tempdir']))
- (failure_count, test_count) = doctest.testmod(globs=globs)
+ (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
exit(-1)
diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py
index b420d7a7f23ba..538ff26ce7c33 100644
--- a/python/pyspark/mllib/__init__.py
+++ b/python/pyspark/mllib/__init__.py
@@ -19,11 +19,7 @@
Python bindings for MLlib.
"""
-# MLlib currently needs Python 2.7+ and NumPy 1.7+, so complain if lower
-
-import sys
-if sys.version_info[0:2] < (2, 7):
- raise Exception("MLlib requires Python 2.7+")
+# MLlib currently needs and NumPy 1.7+, so complain if lower
import numpy
if numpy.version.version < '1.7':
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index 19b90dfd6e167..d2f9cdb3f4298 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -87,18 +87,19 @@ class NaiveBayesModel(object):
>>> data = array([0.0, 0.0, 1.0, 0.0, 0.0, 2.0, 1.0, 1.0, 0.0]).reshape(3,3)
>>> model = NaiveBayes.train(sc.parallelize(data))
>>> model.predict(array([0.0, 1.0]))
- 0
+ 0.0
>>> model.predict(array([1.0, 0.0]))
- 1
+ 1.0
"""
- def __init__(self, pi, theta):
+ def __init__(self, labels, pi, theta):
+ self.labels = labels
self.pi = pi
self.theta = theta
def predict(self, x):
"""Return the most likely class for a data vector x"""
- return numpy.argmax(self.pi + dot(x, self.theta))
+ return self.labels[numpy.argmax(self.pi + dot(x, self.theta))]
class NaiveBayes(object):
@classmethod
@@ -122,7 +123,8 @@ def train(cls, data, lambda_=1.0):
ans = sc._jvm.PythonMLLibAPI().trainNaiveBayes(dataBytes._jrdd, lambda_)
return NaiveBayesModel(
_deserialize_double_vector(ans[0]),
- _deserialize_double_matrix(ans[1]))
+ _deserialize_double_vector(ans[1]),
+ _deserialize_double_matrix(ans[2]))
def _test():
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 019c249699c2d..fb27863e07f55 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -29,7 +29,7 @@
from tempfile import NamedTemporaryFile
from threading import Thread
import warnings
-from heapq import heappush, heappop, heappushpop
+import heapq
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
@@ -41,9 +41,9 @@
from py4j.java_collections import ListConverter, MapConverter
-
__all__ = ["RDD"]
+
def _extract_concise_traceback():
"""
This function returns the traceback info for a callsite, returns a dict
@@ -91,6 +91,73 @@ def __exit__(self, type, value, tb):
if _spark_stack_depth == 0:
self._context._jsc.setCallSite(None)
+class MaxHeapQ(object):
+ """
+ An implementation of MaxHeap.
+ >>> import pyspark.rdd
+ >>> heap = pyspark.rdd.MaxHeapQ(5)
+ >>> [heap.insert(i) for i in range(10)]
+ [None, None, None, None, None, None, None, None, None, None]
+ >>> sorted(heap.getElements())
+ [0, 1, 2, 3, 4]
+ >>> heap = pyspark.rdd.MaxHeapQ(5)
+ >>> [heap.insert(i) for i in range(9, -1, -1)]
+ [None, None, None, None, None, None, None, None, None, None]
+ >>> sorted(heap.getElements())
+ [0, 1, 2, 3, 4]
+ >>> heap = pyspark.rdd.MaxHeapQ(1)
+ >>> [heap.insert(i) for i in range(9, -1, -1)]
+ [None, None, None, None, None, None, None, None, None, None]
+ >>> heap.getElements()
+ [0]
+ """
+
+ def __init__(self, maxsize):
+ # we start from q[1], this makes calculating children as trivial as 2 * k
+ self.q = [0]
+ self.maxsize = maxsize
+
+ def _swim(self, k):
+ while (k > 1) and (self.q[k/2] < self.q[k]):
+ self._swap(k, k/2)
+ k = k/2
+
+ def _swap(self, i, j):
+ t = self.q[i]
+ self.q[i] = self.q[j]
+ self.q[j] = t
+
+ def _sink(self, k):
+ N = self.size()
+ while 2 * k <= N:
+ j = 2 * k
+ # Here we test if both children are greater than parent
+ # if not swap with larger one.
+ if j < N and self.q[j] < self.q[j + 1]:
+ j = j + 1
+ if(self.q[k] > self.q[j]):
+ break
+ self._swap(k, j)
+ k = j
+
+ def size(self):
+ return len(self.q) - 1
+
+ def insert(self, value):
+ if (self.size()) < self.maxsize:
+ self.q.append(value)
+ self._swim(self.size())
+ else:
+ self._replaceRoot(value)
+
+ def getElements(self):
+ return self.q[1:]
+
+ def _replaceRoot(self, value):
+ if(self.q[1] > value):
+ self.q[1] = value
+ self._sink(1)
+
class RDD(object):
"""
A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
@@ -696,16 +763,16 @@ def top(self, num):
Note: It returns the list sorted in descending order.
>>> sc.parallelize([10, 4, 2, 12, 3]).top(1)
[12]
- >>> sc.parallelize([2, 3, 4, 5, 6]).cache().top(2)
+ >>> sc.parallelize([2, 3, 4, 5, 6], 2).cache().top(2)
[6, 5]
"""
def topIterator(iterator):
q = []
for k in iterator:
if len(q) < num:
- heappush(q, k)
+ heapq.heappush(q, k)
else:
- heappushpop(q, k)
+ heapq.heappushpop(q, k)
yield q
def merge(a, b):
@@ -713,6 +780,36 @@ def merge(a, b):
return sorted(self.mapPartitions(topIterator).reduce(merge), reverse=True)
+ def takeOrdered(self, num, key=None):
+ """
+ Get the N elements from a RDD ordered in ascending order or as specified
+ by the optional key function.
+
+ >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6)
+ [1, 2, 3, 4, 5, 6]
+ >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7], 2).takeOrdered(6, key=lambda x: -x)
+ [10, 9, 7, 6, 5, 4]
+ """
+
+ def topNKeyedElems(iterator, key_=None):
+ q = MaxHeapQ(num)
+ for k in iterator:
+ if key_ != None:
+ k = (key_(k), k)
+ q.insert(k)
+ yield q.getElements()
+
+ def unKey(x, key_=None):
+ if key_ != None:
+ x = [i[1] for i in x]
+ return x
+
+ def merge(a, b):
+ return next(topNKeyedElems(a + b))
+ result = self.mapPartitions(lambda i: topNKeyedElems(i, key)).reduce(merge)
+ return sorted(unKey(result, key), key=key)
+
+
def take(self, num):
"""
Take the first num elements of the RDD.
@@ -1205,11 +1302,12 @@ def getStorageLevel(self):
Get the RDD's current storage level.
>>> rdd1 = sc.parallelize([1,2])
>>> rdd1.getStorageLevel()
- StorageLevel(False, False, False, 1)
+ StorageLevel(False, False, False, False, 1)
"""
java_storage_level = self._jrdd.getStorageLevel()
storage_level = StorageLevel(java_storage_level.useDisk(),
java_storage_level.useMemory(),
+ java_storage_level.useOffHeap(),
java_storage_level.deserialized(),
java_storage_level.replication())
return storage_level
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 12c63f186a2b7..b253807974a2e 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -64,6 +64,7 @@
from itertools import chain, izip, product
import marshal
import struct
+import sys
from pyspark import cloudpickle
@@ -113,6 +114,11 @@ class FramedSerializer(Serializer):
where C{length} is a 32-bit integer and data is C{length} bytes.
"""
+ def __init__(self):
+ # On Python 2.6, we can't write bytearrays to streams, so we need to convert them
+ # to strings first. Check if the version number is that old.
+ self._only_write_strings = sys.version_info[0:2] <= (2, 6)
+
def dump_stream(self, iterator, stream):
for obj in iterator:
self._write_with_length(obj, stream)
@@ -127,7 +133,10 @@ def load_stream(self, stream):
def _write_with_length(self, obj, stream):
serialized = self.dumps(obj)
write_int(len(serialized), stream)
- stream.write(serialized)
+ if self._only_write_strings:
+ stream.write(str(serialized))
+ else:
+ stream.write(serialized)
def _read_with_length(self, stream):
length = read_int(stream)
@@ -290,7 +299,7 @@ class MarshalSerializer(FramedSerializer):
class UTF8Deserializer(Serializer):
"""
- Deserializes streams written by getBytes.
+ Deserializes streams written by String.getBytes.
"""
def loads(self, stream):
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index 3d779faf1fa44..35e48276e3cb9 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -29,7 +29,7 @@
# this is the equivalent of ADD_JARS
add_files = os.environ.get("ADD_FILES").split(',') if os.environ.get("ADD_FILES") != None else None
-sc = SparkContext(os.environ.get("MASTER", "local"), "PySparkShell", pyFiles=add_files)
+sc = SparkContext(os.environ.get("MASTER", "local[*]"), "PySparkShell", pyFiles=add_files)
print """Welcome to
____ __
diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py
index c3e3a44e8e7ab..7b6660eab231b 100644
--- a/python/pyspark/storagelevel.py
+++ b/python/pyspark/storagelevel.py
@@ -25,23 +25,25 @@ class StorageLevel:
Also contains static constants for some commonly used storage levels, such as MEMORY_ONLY.
"""
- def __init__(self, useDisk, useMemory, deserialized, replication = 1):
+ def __init__(self, useDisk, useMemory, useOffHeap, deserialized, replication = 1):
self.useDisk = useDisk
self.useMemory = useMemory
+ self.useOffHeap = useOffHeap
self.deserialized = deserialized
self.replication = replication
def __repr__(self):
- return "StorageLevel(%s, %s, %s, %s)" % (
- self.useDisk, self.useMemory, self.deserialized, self.replication)
+ return "StorageLevel(%s, %s, %s, %s, %s)" % (
+ self.useDisk, self.useMemory, self.useOffHeap, self.deserialized, self.replication)
-StorageLevel.DISK_ONLY = StorageLevel(True, False, False)
-StorageLevel.DISK_ONLY_2 = StorageLevel(True, False, False, 2)
-StorageLevel.MEMORY_ONLY = StorageLevel(False, True, True)
-StorageLevel.MEMORY_ONLY_2 = StorageLevel(False, True, True, 2)
-StorageLevel.MEMORY_ONLY_SER = StorageLevel(False, True, False)
-StorageLevel.MEMORY_ONLY_SER_2 = StorageLevel(False, True, False, 2)
-StorageLevel.MEMORY_AND_DISK = StorageLevel(True, True, True)
-StorageLevel.MEMORY_AND_DISK_2 = StorageLevel(True, True, True, 2)
-StorageLevel.MEMORY_AND_DISK_SER = StorageLevel(True, True, False)
-StorageLevel.MEMORY_AND_DISK_SER_2 = StorageLevel(True, True, False, 2)
+StorageLevel.DISK_ONLY = StorageLevel(True, False, False, False)
+StorageLevel.DISK_ONLY_2 = StorageLevel(True, False, False, False, 2)
+StorageLevel.MEMORY_ONLY = StorageLevel(False, True, False, True)
+StorageLevel.MEMORY_ONLY_2 = StorageLevel(False, True, False, True, 2)
+StorageLevel.MEMORY_ONLY_SER = StorageLevel(False, True, False, False)
+StorageLevel.MEMORY_ONLY_SER_2 = StorageLevel(False, True, False, False, 2)
+StorageLevel.MEMORY_AND_DISK = StorageLevel(True, True, False, True)
+StorageLevel.MEMORY_AND_DISK_2 = StorageLevel(True, True, False, True, 2)
+StorageLevel.MEMORY_AND_DISK_SER = StorageLevel(True, True, False, False)
+StorageLevel.MEMORY_AND_DISK_SER_2 = StorageLevel(True, True, False, False, 2)
+StorageLevel.OFF_HEAP = StorageLevel(False, False, True, False, 1)
\ No newline at end of file
diff --git a/python/run-tests b/python/run-tests
index a986ac9380be4..b2b60f08b48e2 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -29,8 +29,18 @@ FAILED=0
rm -f unit-tests.log
function run_test() {
- SPARK_TESTING=0 $FWDIR/bin/pyspark $1 2>&1 | tee -a unit-tests.log
+ SPARK_TESTING=0 $FWDIR/bin/pyspark $1 2>&1 | tee -a > unit-tests.log
FAILED=$((PIPESTATUS[0]||$FAILED))
+
+ # Fail and exit on the first test failure.
+ if [[ $FAILED != 0 ]]; then
+ cat unit-tests.log | grep -v "^[0-9][0-9]*" # filter all lines starting with a number.
+ echo -en "\033[31m" # Red
+ echo "Had test failures; see logs."
+ echo -en "\033[0m" # No color
+ exit -1
+ fi
+
}
run_test "pyspark/rdd.py"
@@ -46,12 +56,7 @@ run_test "pyspark/mllib/clustering.py"
run_test "pyspark/mllib/recommendation.py"
run_test "pyspark/mllib/regression.py"
-if [[ $FAILED != 0 ]]; then
- echo -en "\033[31m" # Red
- echo "Had test failures; see logs."
- echo -en "\033[0m" # No color
- exit -1
-else
+if [[ $FAILED == 0 ]]; then
echo -en "\033[32m" # Green
echo "Tests passed."
echo -en "\033[0m" # No color
diff --git a/repl/pom.xml b/repl/pom.xml
index fc49c8b811316..78d2fe13c27eb 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -77,6 +77,11 @@
scala-compiler
${scala.version}
+
+ org.scala-lang
+ scala-reflect
+ ${scala.version}
+
org.scala-lang
jline
diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
index ee972887feda6..bf73800388ebf 100644
--- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
@@ -124,8 +124,8 @@ extends ClassVisitor(ASM4, cv) {
mv.visitVarInsn(ALOAD, 0) // load this
mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "", "()V")
mv.visitVarInsn(ALOAD, 0) // load this
- //val classType = className.replace('.', '/')
- //mv.visitFieldInsn(PUTSTATIC, classType, "MODULE$", "L" + classType + ";")
+ // val classType = className.replace('.', '/')
+ // mv.visitFieldInsn(PUTSTATIC, classType, "MODULE$", "L" + classType + ";")
mv.visitInsn(RETURN)
mv.visitMaxs(-1, -1) // stack size and local vars will be auto-computed
mv.visitEnd()
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
index 9b1da195002c2..5a367b6bb79de 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
@@ -963,7 +963,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
case Some(m) => m
case None => {
val prop = System.getenv("MASTER")
- if (prop != null) prop else "local"
+ if (prop != null) prop else "local[*]"
}
}
master
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
index 90a96ad38381e..fa2f1a88c4eb5 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
@@ -834,7 +834,7 @@ import org.apache.spark.util.Utils
}
((pos, msg)) :: loop(filtered)
}
- //PRASHANT: This leads to a NoSuchMethodError for _.warnings. Yet to figure out its purpose.
+ // PRASHANT: This leads to a NoSuchMethodError for _.warnings. Yet to figure out its purpose.
// val warnings = loop(run.allConditionalWarnings flatMap (_.warnings))
// if (warnings.nonEmpty)
// mostRecentWarnings = warnings
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala b/repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala
index 946e71039088d..0db26c3407dff 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala
@@ -7,8 +7,10 @@
package org.apache.spark.repl
+import scala.reflect.io.{Path, File}
import scala.tools.nsc._
import scala.tools.nsc.interpreter._
+import scala.tools.nsc.interpreter.session.JLineHistory.JLineFileHistory
import scala.tools.jline.console.ConsoleReader
import scala.tools.jline.console.completer._
@@ -25,7 +27,7 @@ class SparkJLineReader(_completion: => Completion) extends InteractiveReader {
val consoleReader = new JLineConsoleReader()
lazy val completion = _completion
- lazy val history: JLineHistory = JLineHistory()
+ lazy val history: JLineHistory = new SparkJLineHistory
private def term = consoleReader.getTerminal()
def reset() = term.reset()
@@ -78,3 +80,11 @@ class SparkJLineReader(_completion: => Completion) extends InteractiveReader {
def readOneLine(prompt: String) = consoleReader readLine prompt
def readOneKey(prompt: String) = consoleReader readOneKey prompt
}
+
+/** Changes the default history file to not collide with the scala repl's. */
+class SparkJLineHistory extends JLineFileHistory {
+ import Properties.userHome
+
+ def defaultFileName = ".spark_history"
+ override protected lazy val historyFile = File(Path(userHome) / defaultFileName)
+}
diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 8203b8f6122e1..4155007c6d337 100644
--- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -242,4 +242,15 @@ class ReplSuite extends FunSuite {
assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output)
}
}
+
+ test("collecting objects of class defined in repl") {
+ val output = runInterpreter("local[2]",
+ """
+ |case class Foo(i: Int)
+ |val ret = sc.parallelize((1 to 100).map(Foo), 10).collect
+ """.stripMargin)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("ret: Array[Foo] = Array(Foo(1),", output)
+ }
}
diff --git a/sbt/sbt b/sbt/sbt
index 3ffa4ed9ab5a7..9de265bd07dcb 100755
--- a/sbt/sbt
+++ b/sbt/sbt
@@ -1,5 +1,13 @@
#!/usr/bin/env bash
+# When creating new tests for Spark SQL Hive, the HADOOP_CLASSPATH must contain the hive jars so
+# that we can run Hive to generate the golden answer. This is not required for normal development
+# or testing.
+for i in $HIVE_HOME/lib/*
+do HADOOP_CLASSPATH=$HADOOP_CLASSPATH:$i
+done
+export HADOOP_CLASSPATH
+
realpath () {
(
TARGET_FILE=$1
diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index ee968c53b3e4b..76ba1ecca33ab 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -140,4 +140,5 @@
+
diff --git a/sql/README.md b/sql/README.md
index 4192fecb92fb0..14d5555f0c713 100644
--- a/sql/README.md
+++ b/sql/README.md
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.TestHive._
-Welcome to Scala version 2.10.3 (Java HotSpot(TM) 64-Bit Server VM, Java 1.7.0_45).
+Welcome to Scala version 2.10.4 (Java HotSpot(TM) 64-Bit Server VM, Java 1.7.0_45).
Type in expressions to have them evaluated.
Type :help for more information.
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 740f1fdc83299..9d5c6a857bb00 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -31,7 +31,23 @@
Spark Project Catalyst
http://spark.apache.org/
+
+
+ yarn-alpha
+
+
+ org.apache.avro
+ avro
+
+
+
+
+
+
+ org.scala-lang
+ scala-reflect
+
org.apache.spark
spark-core_${scala.binary.version}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 976dda8d7e59a..446d0e0bd7f54 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst
+import java.sql.Timestamp
+
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
@@ -43,15 +45,26 @@ object ScalaReflection {
val params = t.member("": TermName).asMethod.paramss
StructType(
params.head.map(p => StructField(p.name.toString, schemaFor(p.typeSignature), true)))
+ // Need to decide if we actually need a special type here.
+ case t if t <:< typeOf[Array[Byte]] => BinaryType
+ case t if t <:< typeOf[Array[_]] =>
+ sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
ArrayType(schemaFor(elementType))
+ case t if t <:< typeOf[Map[_,_]] =>
+ val TypeRef(_, _, Seq(keyType, valueType)) = t
+ MapType(schemaFor(keyType), schemaFor(valueType))
case t if t <:< typeOf[String] => StringType
+ case t if t <:< typeOf[Timestamp] => TimestampType
+ case t if t <:< typeOf[BigDecimal] => DecimalType
case t if t <:< definitions.IntTpe => IntegerType
case t if t <:< definitions.LongTpe => LongType
case t if t <:< definitions.DoubleTpe => DoubleType
+ case t if t <:< definitions.FloatTpe => FloatType
case t if t <:< definitions.ShortTpe => ShortType
case t if t <:< definitions.ByteTpe => ByteType
+ case t if t <:< definitions.BooleanTpe => BooleanType
}
implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 9dec4e3d9e4c2..5b6aea81cb7d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -106,6 +106,8 @@ class SqlParser extends StandardTokenParsers {
protected val IF = Keyword("IF")
protected val IN = Keyword("IN")
protected val INNER = Keyword("INNER")
+ protected val INSERT = Keyword("INSERT")
+ protected val INTO = Keyword("INTO")
protected val IS = Keyword("IS")
protected val JOIN = Keyword("JOIN")
protected val LEFT = Keyword("LEFT")
@@ -114,6 +116,10 @@ class SqlParser extends StandardTokenParsers {
protected val NULL = Keyword("NULL")
protected val ON = Keyword("ON")
protected val OR = Keyword("OR")
+ protected val OVERWRITE = Keyword("OVERWRITE")
+ protected val LIKE = Keyword("LIKE")
+ protected val RLIKE = Keyword("RLIKE")
+ protected val REGEXP = Keyword("REGEXP")
protected val ORDER = Keyword("ORDER")
protected val OUTER = Keyword("OUTER")
protected val RIGHT = Keyword("RIGHT")
@@ -159,7 +165,7 @@ class SqlParser extends StandardTokenParsers {
select * (
UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } |
UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) }
- )
+ ) | insert
protected lazy val select: Parser[LogicalPlan] =
SELECT ~> opt(DISTINCT) ~ projections ~
@@ -178,10 +184,17 @@ class SqlParser extends StandardTokenParsers {
val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection)
val withHaving = h.map(h => Filter(h, withDistinct)).getOrElse(withDistinct)
val withOrder = o.map(o => Sort(o, withHaving)).getOrElse(withHaving)
- val withLimit = l.map { l => StopAfter(l, withOrder) }.getOrElse(withOrder)
+ val withLimit = l.map { l => Limit(l, withOrder) }.getOrElse(withOrder)
withLimit
}
+ protected lazy val insert: Parser[LogicalPlan] =
+ INSERT ~> opt(OVERWRITE) ~ inTo ~ select <~ opt(";") ^^ {
+ case o ~ r ~ s =>
+ val overwrite: Boolean = o.getOrElse("") == "OVERWRITE"
+ InsertIntoTable(r, Map[String, Option[String]](), s, overwrite)
+ }
+
protected lazy val projections: Parser[Seq[Expression]] = repsep(projection, ",")
protected lazy val projection: Parser[Expression] =
@@ -192,6 +205,8 @@ class SqlParser extends StandardTokenParsers {
protected lazy val from: Parser[LogicalPlan] = FROM ~> relations
+ protected lazy val inTo: Parser[LogicalPlan] = INTO ~> relation
+
// Based very loosely on the MySQL Grammar.
// http://dev.mysql.com/doc/refman/5.0/en/join.html
protected lazy val relations: Parser[LogicalPlan] =
@@ -204,7 +219,7 @@ class SqlParser extends StandardTokenParsers {
protected lazy val relationFactor: Parser[LogicalPlan] =
ident ~ (opt(AS) ~> opt(ident)) ^^ {
- case ident ~ alias => UnresolvedRelation(alias, ident)
+ case tableName ~ alias => UnresolvedRelation(None, tableName, alias)
} |
"(" ~> query ~ ")" ~ opt(AS) ~ ident ^^ { case s ~ _ ~ _ ~ a => Subquery(a, s) }
@@ -267,6 +282,9 @@ class SqlParser extends StandardTokenParsers {
termExpression ~ ">=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThanOrEqual(e1, e2) } |
termExpression ~ "!=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(Equals(e1, e2)) } |
termExpression ~ "<>" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(Equals(e1, e2)) } |
+ termExpression ~ RLIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } |
+ termExpression ~ REGEXP ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } |
+ termExpression ~ LIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => Like(e1, e2) } |
termExpression ~ IN ~ "(" ~ rep1sep(termExpression, ",") <~ ")" ^^ {
case e1 ~ _ ~ _ ~ e2 => In(e1, e2)
} |
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
index ff66177a03b8c..f30b5d816703a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
@@ -31,26 +31,42 @@ trait Catalog {
alias: Option[String] = None): LogicalPlan
def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit
+
+ def unregisterTable(databaseName: Option[String], tableName: String): Unit
+
+ def unregisterAllTables(): Unit
}
class SimpleCatalog extends Catalog {
val tables = new mutable.HashMap[String, LogicalPlan]()
- def registerTable(databaseName: Option[String],tableName: String, plan: LogicalPlan): Unit = {
+ override def registerTable(
+ databaseName: Option[String],
+ tableName: String,
+ plan: LogicalPlan): Unit = {
tables += ((tableName, plan))
}
- def dropTable(tableName: String) = tables -= tableName
+ override def unregisterTable(
+ databaseName: Option[String],
+ tableName: String) = {
+ tables -= tableName
+ }
- def lookupRelation(
+ override def unregisterAllTables() = {
+ tables.clear()
+ }
+
+ override def lookupRelation(
databaseName: Option[String],
tableName: String,
alias: Option[String] = None): LogicalPlan = {
val table = tables.get(tableName).getOrElse(sys.error(s"Table Not Found: $tableName"))
+ val tableWithQualifiers = Subquery(tableName, table)
// If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
// properly qualified with this alias.
- alias.map(a => Subquery(a.toLowerCase, table)).getOrElse(table)
+ alias.map(a => Subquery(a.toLowerCase, tableWithQualifiers)).getOrElse(tableWithQualifiers)
}
}
@@ -86,6 +102,14 @@ trait OverrideCatalog extends Catalog {
plan: LogicalPlan): Unit = {
overrides.put((databaseName, tableName), plan)
}
+
+ override def unregisterTable(databaseName: Option[String], tableName: String): Unit = {
+ overrides.remove((databaseName, tableName))
+ }
+
+ override def unregisterAllTables(): Unit = {
+ overrides.clear()
+ }
}
/**
@@ -103,4 +127,10 @@ object EmptyCatalog extends Catalog {
def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = {
throw new UnsupportedOperationException
}
+
+ def unregisterTable(databaseName: Option[String], tableName: String): Unit = {
+ throw new UnsupportedOperationException
+ }
+
+ override def unregisterAllTables(): Unit = {}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 67cddb351c185..2d62e4cbbce01 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst
+import java.sql.Timestamp
+
import scala.language.implicitConversions
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
@@ -68,7 +70,11 @@ package object dsl {
def > (other: Expression) = GreaterThan(expr, other)
def >= (other: Expression) = GreaterThanOrEqual(expr, other)
def === (other: Expression) = Equals(expr, other)
- def != (other: Expression) = Not(Equals(expr, other))
+ def !== (other: Expression) = Not(Equals(expr, other))
+
+ def like(other: Expression) = Like(expr, other)
+ def rlike(other: Expression) = RLike(expr, other)
+ def cast(to: DataType) = Cast(expr, to)
def asc = SortOrder(expr, Ascending)
def desc = SortOrder(expr, Descending)
@@ -81,27 +87,64 @@ package object dsl {
def expr = e
}
+ implicit def booleanToLiteral(b: Boolean) = Literal(b)
+ implicit def byteToLiteral(b: Byte) = Literal(b)
+ implicit def shortToLiteral(s: Short) = Literal(s)
implicit def intToLiteral(i: Int) = Literal(i)
implicit def longToLiteral(l: Long) = Literal(l)
implicit def floatToLiteral(f: Float) = Literal(f)
implicit def doubleToLiteral(d: Double) = Literal(d)
implicit def stringToLiteral(s: String) = Literal(s)
+ implicit def decimalToLiteral(d: BigDecimal) = Literal(d)
+ implicit def timestampToLiteral(t: Timestamp) = Literal(t)
+ implicit def binaryToLiteral(a: Array[Byte]) = Literal(a)
implicit def symbolToUnresolvedAttribute(s: Symbol) = analysis.UnresolvedAttribute(s.name)
implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name }
- implicit class DslString(val s: String) extends ImplicitAttribute
+ // TODO more implicit class for literal?
+ implicit class DslString(val s: String) extends ImplicitOperators {
+ def expr: Expression = Literal(s)
+ def attr = analysis.UnresolvedAttribute(s)
+ }
abstract class ImplicitAttribute extends ImplicitOperators {
def s: String
def expr = attr
def attr = analysis.UnresolvedAttribute(s)
- /** Creates a new typed attributes of type int */
+ /** Creates a new AttributeReference of type boolean */
+ def boolean = AttributeReference(s, BooleanType, nullable = false)()
+
+ /** Creates a new AttributeReference of type byte */
+ def byte = AttributeReference(s, ByteType, nullable = false)()
+
+ /** Creates a new AttributeReference of type short */
+ def short = AttributeReference(s, ShortType, nullable = false)()
+
+ /** Creates a new AttributeReference of type int */
def int = AttributeReference(s, IntegerType, nullable = false)()
- /** Creates a new typed attributes of type string */
+ /** Creates a new AttributeReference of type long */
+ def long = AttributeReference(s, LongType, nullable = false)()
+
+ /** Creates a new AttributeReference of type float */
+ def float = AttributeReference(s, FloatType, nullable = false)()
+
+ /** Creates a new AttributeReference of type double */
+ def double = AttributeReference(s, DoubleType, nullable = false)()
+
+ /** Creates a new AttributeReference of type string */
def string = AttributeReference(s, StringType, nullable = false)()
+
+ /** Creates a new AttributeReference of type decimal */
+ def decimal = AttributeReference(s, DecimalType, nullable = false)()
+
+ /** Creates a new AttributeReference of type timestamp */
+ def timestamp = AttributeReference(s, TimestampType, nullable = false)()
+
+ /** Creates a new AttributeReference of type binary */
+ def binary = AttributeReference(s, BinaryType, nullable = false)()
}
implicit class DslAttribute(a: AttributeReference) {
@@ -110,6 +153,8 @@ package object dsl {
// Protobuf terminology
def required = a.withNullability(false)
+
+ def at(ordinal: Int) = BoundReference(ordinal, a)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index f70e80b7f27f2..4ebf6c4584b94 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -45,14 +45,20 @@ case class BoundReference(ordinal: Int, baseReference: Attribute)
override def toString = s"$baseReference:$ordinal"
- override def apply(input: Row): Any = input(ordinal)
+ override def eval(input: Row): Any = input(ordinal)
}
+/**
+ * Used to denote operators that do their own binding of attributes internally.
+ */
+trait NoBind { self: trees.TreeNode[_] => }
+
class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] {
import BindReferences._
def apply(plan: TreeNode): TreeNode = {
plan.transform {
+ case n: NoBind => n.asInstanceOf[TreeNode]
case leafNode if leafNode.children.isEmpty => leafNode
case unaryNode if unaryNode.children.size == 1 => unaryNode.transformExpressions { case e =>
bindReference(e, unaryNode.children.head.output)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index c26fc3d0f305f..89226999ca005 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.expressions
+import java.sql.Timestamp
+
import org.apache.spark.sql.catalyst.types._
/** Cast the child expression to the target data type. */
@@ -26,52 +28,169 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
override def toString = s"CAST($child, $dataType)"
type EvaluatedType = Any
+
+ def nullOrCast[T](a: Any, func: T => Any): Any = if(a == null) {
+ null
+ } else {
+ func(a.asInstanceOf[T])
+ }
- lazy val castingFunction: Any => Any = (child.dataType, dataType) match {
- case (BinaryType, StringType) => a: Any => new String(a.asInstanceOf[Array[Byte]])
- case (StringType, BinaryType) => a: Any => a.asInstanceOf[String].getBytes
- case (_, StringType) => a: Any => a.toString
- case (StringType, IntegerType) => a: Any => castOrNull(a, _.toInt)
- case (StringType, DoubleType) => a: Any => castOrNull(a, _.toDouble)
- case (StringType, FloatType) => a: Any => castOrNull(a, _.toFloat)
- case (StringType, LongType) => a: Any => castOrNull(a, _.toLong)
- case (StringType, ShortType) => a: Any => castOrNull(a, _.toShort)
- case (StringType, ByteType) => a: Any => castOrNull(a, _.toByte)
- case (StringType, DecimalType) => a: Any => castOrNull(a, BigDecimal(_))
- case (BooleanType, ByteType) => {
- case null => null
- case true => 1.toByte
- case false => 0.toByte
- }
- case (dt, IntegerType) =>
- a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a)
- case (dt, DoubleType) =>
- a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toDouble(a)
- case (dt, FloatType) =>
- a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toFloat(a)
- case (dt, LongType) =>
- a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toLong(a)
- case (dt, ShortType) =>
- a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a).toShort
- case (dt, ByteType) =>
- a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a).toByte
- case (dt, DecimalType) =>
- a: Any =>
- BigDecimal(dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toDouble(a))
+ // UDFToString
+ def castToString: Any => Any = child.dataType match {
+ case BinaryType => nullOrCast[Array[Byte]](_, new String(_, "UTF-8"))
+ case _ => nullOrCast[Any](_, _.toString)
+ }
+
+ // BinaryConverter
+ def castToBinary: Any => Any = child.dataType match {
+ case StringType => nullOrCast[String](_, _.getBytes("UTF-8"))
}
- @inline
- protected def castOrNull[A](a: Any, f: String => A) =
- try f(a.asInstanceOf[String]) catch {
- case _: java.lang.NumberFormatException => null
- }
+ // UDFToBoolean
+ def castToBoolean: Any => Any = child.dataType match {
+ case StringType => nullOrCast[String](_, _.length() != 0)
+ case TimestampType => nullOrCast[Timestamp](_, b => {(b.getTime() != 0 || b.getNanos() != 0)})
+ case LongType => nullOrCast[Long](_, _ != 0)
+ case IntegerType => nullOrCast[Int](_, _ != 0)
+ case ShortType => nullOrCast[Short](_, _ != 0)
+ case ByteType => nullOrCast[Byte](_, _ != 0)
+ case DecimalType => nullOrCast[BigDecimal](_, _ != 0)
+ case DoubleType => nullOrCast[Double](_, _ != 0)
+ case FloatType => nullOrCast[Float](_, _ != 0)
+ }
+
+ // TimestampConverter
+ def castToTimestamp: Any => Any = child.dataType match {
+ case StringType => nullOrCast[String](_, s => {
+ // Throw away extra if more than 9 decimal places
+ val periodIdx = s.indexOf(".");
+ var n = s
+ if (periodIdx != -1) {
+ if (n.length() - periodIdx > 9) {
+ n = n.substring(0, periodIdx + 10)
+ }
+ }
+ try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null}
+ })
+ case BooleanType => nullOrCast[Boolean](_, b => new Timestamp((if(b) 1 else 0) * 1000))
+ case LongType => nullOrCast[Long](_, l => new Timestamp(l * 1000))
+ case IntegerType => nullOrCast[Int](_, i => new Timestamp(i * 1000))
+ case ShortType => nullOrCast[Short](_, s => new Timestamp(s * 1000))
+ case ByteType => nullOrCast[Byte](_, b => new Timestamp(b * 1000))
+ // TimestampWritable.decimalToTimestamp
+ case DecimalType => nullOrCast[BigDecimal](_, d => decimalToTimestamp(d))
+ // TimestampWritable.doubleToTimestamp
+ case DoubleType => nullOrCast[Double](_, d => decimalToTimestamp(d))
+ // TimestampWritable.floatToTimestamp
+ case FloatType => nullOrCast[Float](_, f => decimalToTimestamp(f))
+ }
+
+ private def decimalToTimestamp(d: BigDecimal) = {
+ val seconds = d.longValue()
+ val bd = (d - seconds) * (1000000000)
+ val nanos = bd.intValue()
+
+ // Convert to millis
+ val millis = seconds * 1000
+ val t = new Timestamp(millis)
+
+ // remaining fractional portion as nanos
+ t.setNanos(nanos)
+
+ t
+ }
+
+ private def timestampToDouble(t: Timestamp) = (t.getSeconds() + t.getNanos().toDouble / 1000)
+
+ def castToLong: Any => Any = child.dataType match {
+ case StringType => nullOrCast[String](_, s => try s.toLong catch {
+ case _: NumberFormatException => null
+ })
+ case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+ case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toLong)
+ case DecimalType => nullOrCast[BigDecimal](_, _.toLong)
+ case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
+ }
+
+ def castToInt: Any => Any = child.dataType match {
+ case StringType => nullOrCast[String](_, s => try s.toInt catch {
+ case _: NumberFormatException => null
+ })
+ case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+ case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toInt)
+ case DecimalType => nullOrCast[BigDecimal](_, _.toInt)
+ case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
+ }
+
+ def castToShort: Any => Any = child.dataType match {
+ case StringType => nullOrCast[String](_, s => try s.toShort catch {
+ case _: NumberFormatException => null
+ })
+ case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+ case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toShort)
+ case DecimalType => nullOrCast[BigDecimal](_, _.toShort)
+ case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
+ }
+
+ def castToByte: Any => Any = child.dataType match {
+ case StringType => nullOrCast[String](_, s => try s.toByte catch {
+ case _: NumberFormatException => null
+ })
+ case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+ case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toByte)
+ case DecimalType => nullOrCast[BigDecimal](_, _.toByte)
+ case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
+ }
+
+ def castToDecimal: Any => Any = child.dataType match {
+ case StringType => nullOrCast[String](_, s => try BigDecimal(s.toDouble) catch {
+ case _: NumberFormatException => null
+ })
+ case BooleanType => nullOrCast[Boolean](_, b => if(b) BigDecimal(1) else BigDecimal(0))
+ case TimestampType => nullOrCast[Timestamp](_, t => BigDecimal(timestampToDouble(t)))
+ case x: NumericType => b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b))
+ }
+
+ def castToDouble: Any => Any = child.dataType match {
+ case StringType => nullOrCast[String](_, s => try s.toDouble catch {
+ case _: NumberFormatException => null
+ })
+ case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+ case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t))
+ case DecimalType => nullOrCast[BigDecimal](_, _.toDouble)
+ case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
+ }
+
+ def castToFloat: Any => Any = child.dataType match {
+ case StringType => nullOrCast[String](_, s => try s.toFloat catch {
+ case _: NumberFormatException => null
+ })
+ case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+ case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toFloat)
+ case DecimalType => nullOrCast[BigDecimal](_, _.toFloat)
+ case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
+ }
+
+ def cast: Any => Any = dataType match {
+ case StringType => castToString
+ case BinaryType => castToBinary
+ case DecimalType => castToDecimal
+ case TimestampType => castToTimestamp
+ case BooleanType => castToBoolean
+ case ByteType => castToByte
+ case ShortType => castToShort
+ case IntegerType => castToInt
+ case FloatType => castToFloat
+ case LongType => castToLong
+ case DoubleType => castToDouble
+ }
- override def apply(input: Row): Any = {
- val evaluated = child.apply(input)
+ override def eval(input: Row): Any = {
+ val evaluated = child.eval(input)
if (evaluated == null) {
null
} else {
- castingFunction(evaluated)
+ cast(evaluated)
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 81fd160e00ca1..f190bd0cca375 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -17,10 +17,10 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
-import org.apache.spark.sql.catalyst.types.{DataType, FractionalType, IntegralType, NumericType}
+import org.apache.spark.sql.catalyst.types.{DataType, FractionalType, IntegralType, NumericType, NativeType}
abstract class Expression extends TreeNode[Expression] {
self: Product =>
@@ -50,7 +50,7 @@ abstract class Expression extends TreeNode[Expression] {
def references: Set[Attribute]
/** Returns the result of evaluating this expression on a given input Row */
- def apply(input: Row = null): EvaluatedType =
+ def eval(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
/**
@@ -73,7 +73,7 @@ abstract class Expression extends TreeNode[Expression] {
*/
@inline
def n1(e: Expression, i: Row, f: ((Numeric[Any], Any) => Any)): Any = {
- val evalE = e.apply(i)
+ val evalE = e.eval(i)
if (evalE == null) {
null
} else {
@@ -86,6 +86,11 @@ abstract class Expression extends TreeNode[Expression] {
}
}
+ /**
+ * Evaluation helper function for 2 Numeric children expressions. Those expressions are supposed
+ * to be in the same data type, and also the return type.
+ * Either one of the expressions result is null, the evaluation result should be null.
+ */
@inline
protected final def n2(
i: Row,
@@ -97,11 +102,11 @@ abstract class Expression extends TreeNode[Expression] {
throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}")
}
- val evalE1 = e1.apply(i)
+ val evalE1 = e1.eval(i)
if(evalE1 == null) {
null
} else {
- val evalE2 = e2.apply(i)
+ val evalE2 = e2.eval(i)
if (evalE2 == null) {
null
} else {
@@ -115,6 +120,11 @@ abstract class Expression extends TreeNode[Expression] {
}
}
+ /**
+ * Evaluation helper function for 2 Fractional children expressions. Those expressions are
+ * supposed to be in the same data type, and also the return type.
+ * Either one of the expressions result is null, the evaluation result should be null.
+ */
@inline
protected final def f2(
i: Row,
@@ -125,11 +135,11 @@ abstract class Expression extends TreeNode[Expression] {
throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}")
}
- val evalE1 = e1.apply(i: Row)
+ val evalE1 = e1.eval(i: Row)
if(evalE1 == null) {
null
} else {
- val evalE2 = e2.apply(i: Row)
+ val evalE2 = e2.eval(i: Row)
if (evalE2 == null) {
null
} else {
@@ -143,6 +153,11 @@ abstract class Expression extends TreeNode[Expression] {
}
}
+ /**
+ * Evaluation helper function for 2 Integral children expressions. Those expressions are
+ * supposed to be in the same data type, and also the return type.
+ * Either one of the expressions result is null, the evaluation result should be null.
+ */
@inline
protected final def i2(
i: Row,
@@ -153,11 +168,11 @@ abstract class Expression extends TreeNode[Expression] {
throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}")
}
- val evalE1 = e1.apply(i)
+ val evalE1 = e1.eval(i)
if(evalE1 == null) {
null
} else {
- val evalE2 = e2.apply(i)
+ val evalE2 = e2.eval(i)
if (evalE2 == null) {
null
} else {
@@ -170,6 +185,43 @@ abstract class Expression extends TreeNode[Expression] {
}
}
}
+
+ /**
+ * Evaluation helper function for 2 Comparable children expressions. Those expressions are
+ * supposed to be in the same data type, and the return type should be Integer:
+ * Negative value: 1st argument less than 2nd argument
+ * Zero: 1st argument equals 2nd argument
+ * Positive value: 1st argument greater than 2nd argument
+ *
+ * Either one of the expressions result is null, the evaluation result should be null.
+ */
+ @inline
+ protected final def c2(
+ i: Row,
+ e1: Expression,
+ e2: Expression,
+ f: ((Ordering[Any], Any, Any) => Any)): Any = {
+ if (e1.dataType != e2.dataType) {
+ throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}")
+ }
+
+ val evalE1 = e1.eval(i)
+ if(evalE1 == null) {
+ null
+ } else {
+ val evalE2 = e2.eval(i)
+ if (evalE2 == null) {
+ null
+ } else {
+ e1.dataType match {
+ case i: NativeType =>
+ f.asInstanceOf[(Ordering[i.JvmType], i.JvmType, i.JvmType) => Boolean](
+ i.ordering, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType])
+ case other => sys.error(s"Type $other does not support ordered operations")
+ }
+ }
+ }
+ }
}
abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
@@ -179,7 +231,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
override def foldable = left.foldable && right.foldable
- def references = left.references ++ right.references
+ override def references = left.references ++ right.references
override def toString = s"($left $symbol $right)"
}
@@ -191,5 +243,5 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression]
abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] {
self: Product =>
- def references = child.references
+ override def references = child.references
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 38542d3fc7290..c9b7cea6a3e5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -27,11 +27,12 @@ class Projection(expressions: Seq[Expression]) extends (Row => Row) {
this(expressions.map(BindReferences.bindReference(_, inputSchema)))
protected val exprArray = expressions.toArray
+
def apply(input: Row): Row = {
- val outputArray = new Array[Any](exprArray.size)
+ val outputArray = new Array[Any](exprArray.length)
var i = 0
- while (i < exprArray.size) {
- outputArray(i) = exprArray(i).apply(input)
+ while (i < exprArray.length) {
+ outputArray(i) = exprArray(i).eval(input)
i += 1
}
new GenericRow(outputArray)
@@ -57,8 +58,8 @@ case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row)
def apply(input: Row): Row = {
var i = 0
- while (i < exprArray.size) {
- mutableRow(i) = exprArray(i).apply(input)
+ while (i < exprArray.length) {
+ mutableRow(i) = exprArray(i).eval(input)
i += 1
}
mutableRow
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
index 31d42b9ee71a0..0f06ea088e1a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
@@ -19,6 +19,21 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.types.NativeType
+object Row {
+ /**
+ * This method can be used to extract fields from a [[Row]] object in a pattern match. Example:
+ * {{{
+ * import org.apache.spark.sql._
+ *
+ * val pairs = sql("SELECT key, value FROM src").rdd.map {
+ * case Row(key: Int, value: String) =>
+ * key -> value
+ * }
+ * }}}
+ */
+ def unapplySeq(row: Row): Some[Seq[Any]] = Some(row)
+}
+
/**
* Represents one row of output from a relational operator. Allows both generic access by ordinal,
* which will incur boxing overhead for primitives, as well as native primitive access.
@@ -44,6 +59,16 @@ trait Row extends Seq[Any] with Serializable {
s"[${this.mkString(",")}]"
def copy(): Row
+
+ /** Returns true if there are any NULL values in this row. */
+ def anyNull: Boolean = {
+ var i = 0
+ while (i < length) {
+ if (isNullAt(i)) { return true }
+ i += 1
+ }
+ false
+ }
}
/**
@@ -187,8 +212,8 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] {
var i = 0
while (i < ordering.size) {
val order = ordering(i)
- val left = order.child.apply(a)
- val right = order.child.apply(b)
+ val left = order.child.eval(a)
+ val right = order.child.eval(b)
if (left == null && right == null) {
// Both null, continue looking.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
index f53d8504b083f..5e089f7618e0a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
@@ -27,13 +27,13 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
def references = children.flatMap(_.references).toSet
def nullable = true
- override def apply(input: Row): Any = {
+ override def eval(input: Row): Any = {
children.size match {
- case 1 => function.asInstanceOf[(Any) => Any](children(0).apply(input))
+ case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input))
case 2 =>
function.asInstanceOf[(Any, Any) => Any](
- children(0).apply(input),
- children(1).apply(input))
+ children(0).eval(input),
+ children(1).eval(input))
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
index 9828d0b9bd8b2..e787c59e75723 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala
@@ -30,7 +30,7 @@ case class WrapDynamic(children: Seq[Attribute]) extends Expression {
def references = children.toSet
def dataType = DynamicType
- override def apply(input: Row): DynamicRow = input match {
+ override def eval(input: Row): DynamicRow = input match {
// Avoid copy for generic rows.
case g: GenericRow => new DynamicRow(children, g.values)
case otherRowType => new DynamicRow(children, otherRowType.toArray)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 7303b155cae3d..5edcea14278c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -27,7 +27,7 @@ abstract class AggregateExpression extends Expression {
* Creates a new instance that can be used to compute this aggregate expression for a group
* of input rows/
*/
- def newInstance: AggregateFunction
+ def newInstance(): AggregateFunction
}
/**
@@ -43,7 +43,7 @@ case class SplitEvaluation(
partialEvaluations: Seq[NamedExpression])
/**
- * An [[AggregateExpression]] that can be partially computed without seeing all relevent tuples.
+ * An [[AggregateExpression]] that can be partially computed without seeing all relevant tuples.
* These partial evaluations can then be combined to compute the actual answer.
*/
abstract class PartialAggregate extends AggregateExpression {
@@ -63,48 +63,48 @@ abstract class AggregateFunction
extends AggregateExpression with Serializable with trees.LeafNode[Expression] {
self: Product =>
- type EvaluatedType = Any
+ override type EvaluatedType = Any
/** Base should return the generic aggregate expression that this function is computing */
val base: AggregateExpression
- def references = base.references
- def nullable = base.nullable
- def dataType = base.dataType
+ override def references = base.references
+ override def nullable = base.nullable
+ override def dataType = base.dataType
def update(input: Row): Unit
- override def apply(input: Row): Any
+ override def eval(input: Row): Any
// Do we really need this?
- def newInstance = makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
+ override def newInstance() = makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
}
case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- def references = child.references
- def nullable = false
- def dataType = IntegerType
+ override def references = child.references
+ override def nullable = false
+ override def dataType = IntegerType
override def toString = s"COUNT($child)"
- def asPartial: SplitEvaluation = {
+ override def asPartial: SplitEvaluation = {
val partialCount = Alias(Count(child), "PartialCount")()
SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil)
}
- override def newInstance = new CountFunction(child, this)
+ override def newInstance()= new CountFunction(child, this)
}
case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpression {
- def children = expressions
- def references = expressions.flatMap(_.references).toSet
- def nullable = false
- def dataType = IntegerType
+ override def children = expressions
+ override def references = expressions.flatMap(_.references).toSet
+ override def nullable = false
+ override def dataType = IntegerType
override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")}})"
- override def newInstance = new CountDistinctFunction(expressions, this)
+ override def newInstance()= new CountDistinctFunction(expressions, this)
}
case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- def references = child.references
- def nullable = false
- def dataType = DoubleType
+ override def references = child.references
+ override def nullable = false
+ override def dataType = DoubleType
override def toString = s"AVG($child)"
override def asPartial: SplitEvaluation = {
@@ -118,13 +118,13 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
partialCount :: partialSum :: Nil)
}
- override def newInstance = new AverageFunction(child, this)
+ override def newInstance()= new AverageFunction(child, this)
}
case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- def references = child.references
- def nullable = false
- def dataType = child.dataType
+ override def references = child.references
+ override def nullable = false
+ override def dataType = child.dataType
override def toString = s"SUM($child)"
override def asPartial: SplitEvaluation = {
@@ -134,24 +134,24 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
partialSum :: Nil)
}
- override def newInstance = new SumFunction(child, this)
+ override def newInstance()= new SumFunction(child, this)
}
case class SumDistinct(child: Expression)
extends AggregateExpression with trees.UnaryNode[Expression] {
- def references = child.references
- def nullable = false
- def dataType = child.dataType
+ override def references = child.references
+ override def nullable = false
+ override def dataType = child.dataType
override def toString = s"SUM(DISTINCT $child)"
- override def newInstance = new SumDistinctFunction(child, this)
+ override def newInstance()= new SumDistinctFunction(child, this)
}
case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
- def references = child.references
- def nullable = child.nullable
- def dataType = child.dataType
+ override def references = child.references
+ override def nullable = child.nullable
+ override def dataType = child.dataType
override def toString = s"FIRST($child)"
override def asPartial: SplitEvaluation = {
@@ -160,7 +160,7 @@ case class First(child: Expression) extends PartialAggregate with trees.UnaryNod
First(partialFirst.toAttribute),
partialFirst :: Nil)
}
- override def newInstance = new FirstFunction(child, this)
+ override def newInstance()= new FirstFunction(child, this)
}
case class AverageFunction(expr: Expression, base: AggregateExpression)
@@ -169,17 +169,15 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
def this() = this(null, null) // Required for serialization.
private var count: Long = _
- private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).apply(EmptyRow))
+ private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).eval(EmptyRow))
private val sumAsDouble = Cast(sum, DoubleType)
-
-
private val addFunction = Add(sum, expr)
- override def apply(input: Row): Any =
- sumAsDouble.apply(EmptyRow).asInstanceOf[Double] / count.toDouble
+ override def eval(input: Row): Any =
+ sumAsDouble.eval(EmptyRow).asInstanceOf[Double] / count.toDouble
- def update(input: Row): Unit = {
+ override def update(input: Row): Unit = {
count += 1
sum.update(addFunction, input)
}
@@ -190,28 +188,28 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
var count: Int = _
- def update(input: Row): Unit = {
- val evaluatedExpr = expr.map(_.apply(input))
+ override def update(input: Row): Unit = {
+ val evaluatedExpr = expr.map(_.eval(input))
if (evaluatedExpr.map(_ != null).reduceLeft(_ || _)) {
count += 1
}
}
- override def apply(input: Row): Any = count
+ override def eval(input: Row): Any = count
}
case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
- private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).apply(null))
+ private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).eval(null))
private val addFunction = Add(sum, expr)
- def update(input: Row): Unit = {
+ override def update(input: Row): Unit = {
sum.update(addFunction, input)
}
- override def apply(input: Row): Any = sum.apply(null)
+ override def eval(input: Row): Any = sum.eval(null)
}
case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
@@ -219,16 +217,16 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
def this() = this(null, null) // Required for serialization.
- val seen = new scala.collection.mutable.HashSet[Any]()
+ private val seen = new scala.collection.mutable.HashSet[Any]()
- def update(input: Row): Unit = {
- val evaluatedExpr = expr.apply(input)
+ override def update(input: Row): Unit = {
+ val evaluatedExpr = expr.eval(input)
if (evaluatedExpr != null) {
seen += evaluatedExpr
}
}
- override def apply(input: Row): Any =
+ override def eval(input: Row): Any =
seen.reduceLeft(base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)
}
@@ -239,14 +237,14 @@ case class CountDistinctFunction(expr: Seq[Expression], base: AggregateExpressio
val seen = new scala.collection.mutable.HashSet[Any]()
- def update(input: Row): Unit = {
- val evaluatedExpr = expr.map(_.apply(input))
+ override def update(input: Row): Unit = {
+ val evaluatedExpr = expr.map(_.eval(input))
if (evaluatedExpr.map(_ != null).reduceLeft(_ && _)) {
seen += evaluatedExpr
}
}
- override def apply(input: Row): Any = seen.size
+ override def eval(input: Row): Any = seen.size
}
case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
@@ -254,11 +252,11 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag
var result: Any = null
- def update(input: Row): Unit = {
+ override def update(input: Row): Unit = {
if (result == null) {
- result = expr.apply(input)
+ result = expr.eval(input)
}
}
- override def apply(input: Row): Any = result
+ override def eval(input: Row): Any = result
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index fba056e7c07e3..c79c1847cedf5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -28,7 +28,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression {
def nullable = child.nullable
override def toString = s"-$child"
- override def apply(input: Row): Any = {
+ override def eval(input: Row): Any = {
n1(child, input, _.negate(_))
}
}
@@ -55,25 +55,25 @@ abstract class BinaryArithmetic extends BinaryExpression {
case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "+"
- override def apply(input: Row): Any = n2(input, left, right, _.plus(_, _))
+ override def eval(input: Row): Any = n2(input, left, right, _.plus(_, _))
}
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "-"
- override def apply(input: Row): Any = n2(input, left, right, _.minus(_, _))
+ override def eval(input: Row): Any = n2(input, left, right, _.minus(_, _))
}
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "*"
- override def apply(input: Row): Any = n2(input, left, right, _.times(_, _))
+ override def eval(input: Row): Any = n2(input, left, right, _.times(_, _))
}
case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "/"
- override def apply(input: Row): Any = dataType match {
+ override def eval(input: Row): Any = dataType match {
case _: FractionalType => f2(input, left, right, _.div(_, _))
case _: IntegralType => i2(input, left , right, _.quot(_, _))
}
@@ -83,5 +83,5 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "%"
- override def apply(input: Row): Any = i2(input, left, right, _.rem(_, _))
+ override def eval(input: Row): Any = i2(input, left, right, _.rem(_, _))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
index ab96618d73df7..c947155cb701c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
@@ -39,10 +39,10 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
override def toString = s"$child[$ordinal]"
- override def apply(input: Row): Any = {
+ override def eval(input: Row): Any = {
if (child.dataType.isInstanceOf[ArrayType]) {
- val baseValue = child.apply(input).asInstanceOf[Seq[_]]
- val o = ordinal.apply(input).asInstanceOf[Int]
+ val baseValue = child.eval(input).asInstanceOf[Seq[_]]
+ val o = ordinal.eval(input).asInstanceOf[Int]
if (baseValue == null) {
null
} else if (o >= baseValue.size || o < 0) {
@@ -51,8 +51,8 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
baseValue(o)
}
} else {
- val baseValue = child.apply(input).asInstanceOf[Map[Any, _]]
- val key = ordinal.apply(input)
+ val baseValue = child.eval(input).asInstanceOf[Map[Any, _]]
+ val key = ordinal.eval(input)
if (baseValue == null) {
null
} else {
@@ -85,8 +85,8 @@ case class GetField(child: Expression, fieldName: String) extends UnaryExpressio
override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[StructType]
- override def apply(input: Row): Any = {
- val baseValue = child.apply(input).asInstanceOf[Row]
+ override def eval(input: Row): Any = {
+ val baseValue = child.eval(input).asInstanceOf[Row]
if (baseValue == null) null else baseValue(ordinal)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index e9b491b10a5f2..dd78614754e12 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -35,17 +35,17 @@ import org.apache.spark.sql.catalyst.types._
* requested. The attributes produced by this function will be automatically copied anytime rules
* result in changes to the Generator or its children.
*/
-abstract class Generator extends Expression with (Row => TraversableOnce[Row]) {
+abstract class Generator extends Expression {
self: Product =>
- type EvaluatedType = TraversableOnce[Row]
+ override type EvaluatedType = TraversableOnce[Row]
- lazy val dataType =
+ override lazy val dataType =
ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable))))
- def nullable = false
+ override def nullable = false
- def references = children.flatMap(_.references).toSet
+ override def references = children.flatMap(_.references).toSet
/**
* Should be overridden by specific generators. Called only once for each instance to ensure
@@ -63,7 +63,7 @@ abstract class Generator extends Expression with (Row => TraversableOnce[Row]) {
}
/** Should be implemented by child classes to perform specific Generators. */
- def apply(input: Row): TraversableOnce[Row]
+ override def eval(input: Row): TraversableOnce[Row]
/** Overridden `makeCopy` also copies the attributes that are produced by this generator. */
override def makeCopy(newArgs: Array[AnyRef]): this.type = {
@@ -83,7 +83,7 @@ case class Explode(attributeNames: Seq[String], child: Expression)
child.resolved &&
(child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])
- lazy val elementTypes = child.dataType match {
+ private lazy val elementTypes = child.dataType match {
case ArrayType(et) => et :: Nil
case MapType(kt,vt) => kt :: vt :: Nil
}
@@ -100,13 +100,13 @@ case class Explode(attributeNames: Seq[String], child: Expression)
}
}
- override def apply(input: Row): TraversableOnce[Row] = {
+ override def eval(input: Row): TraversableOnce[Row] = {
child.dataType match {
case ArrayType(_) =>
- val inputArray = child.apply(input).asInstanceOf[Seq[Any]]
+ val inputArray = child.eval(input).asInstanceOf[Seq[Any]]
if (inputArray == null) Nil else inputArray.map(v => new GenericRow(Array(v)))
case MapType(_, _) =>
- val inputMap = child.apply(input).asInstanceOf[Map[Any,Any]]
+ val inputMap = child.eval(input).asInstanceOf[Map[Any,Any]]
if (inputMap == null) Nil else inputMap.map { case (k,v) => new GenericRow(Array(k,v)) }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index b82a12e0f754e..e15e16d633365 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.expressions
+import java.sql.Timestamp
+
import org.apache.spark.sql.catalyst.types._
object Literal {
@@ -29,6 +31,9 @@ object Literal {
case s: Short => Literal(s, ShortType)
case s: String => Literal(s, StringType)
case b: Boolean => Literal(b, BooleanType)
+ case d: BigDecimal => Literal(d, DecimalType)
+ case t: Timestamp => Literal(t, TimestampType)
+ case a: Array[Byte] => Literal(a, BinaryType)
case null => Literal(null, NullType)
}
}
@@ -52,7 +57,7 @@ case class Literal(value: Any, dataType: DataType) extends LeafExpression {
override def toString = if (value != null) value.toString else "null"
type EvaluatedType = Any
- override def apply(input: Row):Any = value
+ override def eval(input: Row):Any = value
}
// TODO: Specialize
@@ -64,8 +69,8 @@ case class MutableLiteral(var value: Any, nullable: Boolean = true) extends Leaf
def references = Set.empty
def update(expression: Expression, input: Row) = {
- value = expression.apply(input)
+ value = expression.eval(input)
}
- override def apply(input: Row) = value
+ override def eval(input: Row) = value
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 69c8bed309c18..eb4bc8e755284 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -79,7 +79,7 @@ case class Alias(child: Expression, name: String)
type EvaluatedType = Any
- override def apply(input: Row) = child.apply(input)
+ override def eval(input: Row) = child.eval(input)
def dataType = child.dataType
def nullable = child.nullable
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
index 5a47768dcb4a1..ce6d99c911ab3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
@@ -41,11 +41,11 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
throw new UnresolvedException(this, "Coalesce cannot have children of different types.")
}
- override def apply(input: Row): Any = {
+ override def eval(input: Row): Any = {
var i = 0
var result: Any = null
while(i < children.size && result == null) {
- result = children(i).apply(input)
+ result = children(i).eval(input)
i += 1
}
result
@@ -57,8 +57,8 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr
override def foldable = child.foldable
def nullable = false
- override def apply(input: Row): Any = {
- child.apply(input) == null
+ override def eval(input: Row): Any = {
+ child.eval(input) == null
}
}
@@ -68,7 +68,7 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E
def nullable = false
override def toString = s"IS NOT NULL $child"
- override def apply(input: Row): Any = {
- child.apply(input) != null
+ override def eval(input: Row): Any = {
+ child.eval(input) != null
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 722ff517d250e..da5b2cf5b0362 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -18,8 +18,15 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
-import org.apache.spark.sql.catalyst.types.{BooleanType, StringType}
+import org.apache.spark.sql.catalyst.types.{BooleanType, StringType, TimestampType}
+
+object InterpretedPredicate {
+ def apply(expression: Expression): (Row => Boolean) = {
+ (r: Row) => expression.eval(r).asInstanceOf[Boolean]
+ }
+}
trait Predicate extends Expression {
self: Product =>
@@ -47,8 +54,8 @@ case class Not(child: Expression) extends Predicate with trees.UnaryNode[Express
def nullable = child.nullable
override def toString = s"NOT $child"
- override def apply(input: Row): Any = {
- child.apply(input) match {
+ override def eval(input: Row): Any = {
+ child.eval(input) match {
case null => null
case b: Boolean => !b
}
@@ -64,18 +71,18 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
def nullable = true // TODO: Figure out correct nullability semantics of IN.
override def toString = s"$value IN ${list.mkString("(", ",", ")")}"
- override def apply(input: Row): Any = {
- val evaluatedValue = value.apply(input)
- list.exists(e => e.apply(input) == evaluatedValue)
+ override def eval(input: Row): Any = {
+ val evaluatedValue = value.eval(input)
+ list.exists(e => e.eval(input) == evaluatedValue)
}
}
case class And(left: Expression, right: Expression) extends BinaryPredicate {
def symbol = "&&"
- override def apply(input: Row): Any = {
- val l = left.apply(input)
- val r = right.apply(input)
+ override def eval(input: Row): Any = {
+ val l = left.eval(input)
+ val r = right.eval(input)
if (l == false || r == false) {
false
} else if (l == null || r == null ) {
@@ -89,9 +96,9 @@ case class And(left: Expression, right: Expression) extends BinaryPredicate {
case class Or(left: Expression, right: Expression) extends BinaryPredicate {
def symbol = "||"
- override def apply(input: Row): Any = {
- val l = left.apply(input)
- val r = right.apply(input)
+ override def eval(input: Row): Any = {
+ val l = left.eval(input)
+ val r = right.eval(input)
if (l == true || r == true) {
true
} else if (l == null || r == null) {
@@ -108,79 +115,31 @@ abstract class BinaryComparison extends BinaryPredicate {
case class Equals(left: Expression, right: Expression) extends BinaryComparison {
def symbol = "="
- override def apply(input: Row): Any = {
- val l = left.apply(input)
- val r = right.apply(input)
+ override def eval(input: Row): Any = {
+ val l = left.eval(input)
+ val r = right.eval(input)
if (l == null || r == null) null else l == r
}
}
case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
def symbol = "<"
- override def apply(input: Row): Any = {
- if (left.dataType == StringType && right.dataType == StringType) {
- val l = left.apply(input)
- val r = right.apply(input)
- if(l == null || r == null) {
- null
- } else {
- l.asInstanceOf[String] < r.asInstanceOf[String]
- }
- } else {
- n2(input, left, right, _.lt(_, _))
- }
- }
+ override def eval(input: Row): Any = c2(input, left, right, _.lt(_, _))
}
case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
def symbol = "<="
- override def apply(input: Row): Any = {
- if (left.dataType == StringType && right.dataType == StringType) {
- val l = left.apply(input)
- val r = right.apply(input)
- if(l == null || r == null) {
- null
- } else {
- l.asInstanceOf[String] <= r.asInstanceOf[String]
- }
- } else {
- n2(input, left, right, _.lteq(_, _))
- }
- }
+ override def eval(input: Row): Any = c2(input, left, right, _.lteq(_, _))
}
case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {
def symbol = ">"
- override def apply(input: Row): Any = {
- if (left.dataType == StringType && right.dataType == StringType) {
- val l = left.apply(input)
- val r = right.apply(input)
- if(l == null || r == null) {
- null
- } else {
- l.asInstanceOf[String] > r.asInstanceOf[String]
- }
- } else {
- n2(input, left, right, _.gt(_, _))
- }
- }
+ override def eval(input: Row): Any = c2(input, left, right, _.gt(_, _))
}
case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
def symbol = ">="
- override def apply(input: Row): Any = {
- if (left.dataType == StringType && right.dataType == StringType) {
- val l = left.apply(input)
- val r = right.apply(input)
- if(l == null || r == null) {
- null
- } else {
- l.asInstanceOf[String] >= r.asInstanceOf[String]
- }
- } else {
- n2(input, left, right, _.gteq(_, _))
- }
- }
+ override def eval(input: Row): Any = c2(input, left, right, _.gteq(_, _))
}
case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
@@ -200,11 +159,11 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
}
type EvaluatedType = Any
- override def apply(input: Row): Any = {
- if (predicate(input).asInstanceOf[Boolean]) {
- trueValue.apply(input)
+ override def eval(input: Row): Any = {
+ if (predicate.eval(input).asInstanceOf[Boolean]) {
+ trueValue.eval(input)
} else {
- falseValue.apply(input)
+ falseValue.eval(input)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index e195f2ac7efd1..a27c71db1b999 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -17,11 +17,101 @@
package org.apache.spark.sql.catalyst.expressions
+import java.util.regex.Pattern
+
+import org.apache.spark.sql.catalyst.types.DataType
+import org.apache.spark.sql.catalyst.types.StringType
import org.apache.spark.sql.catalyst.types.BooleanType
-case class Like(left: Expression, right: Expression) extends BinaryExpression {
- def dataType = BooleanType
- def nullable = left.nullable // Right cannot be null.
+
+trait StringRegexExpression {
+ self: BinaryExpression =>
+
+ type EvaluatedType = Any
+
+ def escape(v: String): String
+ def matches(regex: Pattern, str: String): Boolean
+
+ def nullable: Boolean = true
+ def dataType: DataType = BooleanType
+
+ // try cache the pattern for Literal
+ private lazy val cache: Pattern = right match {
+ case x @ Literal(value: String, StringType) => compile(value)
+ case _ => null
+ }
+
+ protected def compile(str: String): Pattern = if(str == null) {
+ null
+ } else {
+ // Let it raise exception if couldn't compile the regex string
+ Pattern.compile(escape(str))
+ }
+
+ protected def pattern(str: String) = if(cache == null) compile(str) else cache
+
+ override def eval(input: Row): Any = {
+ val l = left.eval(input)
+ if (l == null) {
+ null
+ } else {
+ val r = right.eval(input)
+ if(r == null) {
+ null
+ } else {
+ val regex = pattern(r.asInstanceOf[String])
+ if(regex == null) {
+ null
+ } else {
+ matches(regex, l.asInstanceOf[String])
+ }
+ }
+ }
+ }
+}
+
+/**
+ * Simple RegEx pattern matching function
+ */
+case class Like(left: Expression, right: Expression)
+ extends BinaryExpression with StringRegexExpression {
+
def symbol = "LIKE"
+
+ // replace the _ with .{1} exactly match 1 time of any character
+ // replace the % with .*, match 0 or more times with any character
+ override def escape(v: String) = {
+ val sb = new StringBuilder()
+ var i = 0;
+ while (i < v.length) {
+ // Make a special case for "\\_" and "\\%"
+ val n = v.charAt(i);
+ if (n == '\\' && i + 1 < v.length && (v.charAt(i + 1) == '_' || v.charAt(i + 1) == '%')) {
+ sb.append(v.charAt(i + 1))
+ i += 1
+ } else {
+ if (n == '_') {
+ sb.append(".");
+ } else if (n == '%') {
+ sb.append(".*");
+ } else {
+ sb.append(Pattern.quote(Character.toString(n)));
+ }
+ }
+
+ i += 1
+ }
+
+ sb.toString()
+ }
+
+ override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches()
}
+case class RLike(left: Expression, right: Expression)
+ extends BinaryExpression with StringRegexExpression {
+
+ def symbol = "RLIKE"
+ override def escape(v: String): String = v
+ override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 3dd6818029bcf..37b23ba58289c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -45,7 +45,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
case q: LogicalPlan => q transformExpressionsDown {
// Skip redundant folding of literals.
case l: Literal => l
- case e if e.foldable => Literal(e.apply(null), e.dataType)
+ case e if e.foldable => Literal(e.eval(null), e.dataType)
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 9d16189deedfe..cfc0b0c3a8d98 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -130,7 +130,7 @@ case class Aggregate(
def references = child.references
}
-case class StopAfter(limit: Expression, child: LogicalPlan) extends UnaryNode {
+case class Limit(limit: Expression, child: LogicalPlan) extends UnaryNode {
def output = child.output
def references = limit.references
}
@@ -162,6 +162,7 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode {
a.nullable)(
a.exprId,
a.qualifiers)
+ case other => other
}
def references = Set.empty
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index 7a45d1a1b8195..cdeb01a9656f4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.types
+import java.sql.Timestamp
+
import scala.reflect.runtime.universe.{typeTag, TypeTag}
import org.apache.spark.sql.catalyst.expressions.Expression
@@ -51,6 +53,16 @@ case object BooleanType extends NativeType {
val ordering = implicitly[Ordering[JvmType]]
}
+case object TimestampType extends NativeType {
+ type JvmType = Timestamp
+
+ @transient lazy val tag = typeTag[JvmType]
+
+ val ordering = new Ordering[JvmType] {
+ def compare(x: Timestamp, y: Timestamp) = x.compareTo(y)
+ }
+}
+
abstract class NumericType extends NativeType {
// Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for
// implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala
index 9ec31689b5098..4589129cd1c90 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala
@@ -32,18 +32,5 @@ package object sql {
type Row = catalyst.expressions.Row
- object Row {
- /**
- * This method can be used to extract fields from a [[Row]] object in a pattern match. Example:
- * {{{
- * import org.apache.spark.sql._
- *
- * val pairs = sql("SELECT key, value FROM src").rdd.map {
- * case Row(key: Int, value: String) =>
- * key -> value
- * }
- * }}}
- */
- def unapplySeq(row: Row): Some[Seq[Any]] = Some(row)
- }
+ val Row = catalyst.expressions.Row
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 94894adf81202..92987405aa313 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.expressions
+import java.sql.Timestamp
+
import org.scalatest.FunSuite
import org.apache.spark.sql.catalyst.types._
@@ -27,7 +29,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
class ExpressionEvaluationSuite extends FunSuite {
test("literals") {
- assert((Literal(1) + Literal(1)).apply(null) === 2)
+ assert((Literal(1) + Literal(1)).eval(null) === 2)
}
/**
@@ -60,7 +62,7 @@ class ExpressionEvaluationSuite extends FunSuite {
notTrueTable.foreach {
case (v, answer) =>
val expr = Not(Literal(v, BooleanType))
- val result = expr.apply(null)
+ val result = expr.eval(null)
if (result != answer)
fail(s"$expr should not evaluate to $result, expected: $answer") }
}
@@ -103,10 +105,144 @@ class ExpressionEvaluationSuite extends FunSuite {
truthTable.foreach {
case (l,r,answer) =>
val expr = op(Literal(l, BooleanType), Literal(r, BooleanType))
- val result = expr.apply(null)
+ val result = expr.eval(null)
if (result != answer)
fail(s"$expr should not evaluate to $result, expected: $answer")
}
}
}
+
+ def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = {
+ expression.eval(inputRow)
+ }
+
+ def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = {
+ val actual = try evaluate(expression, inputRow) catch {
+ case e: Exception => fail(s"Exception evaluating $expression", e)
+ }
+ if(actual != expected) {
+ val input = if(inputRow == EmptyRow) "" else s", input: $inputRow"
+ fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
+ }
+ }
+
+ test("LIKE literal Regular Expression") {
+ checkEvaluation(Literal(null, StringType).like("a"), null)
+ checkEvaluation(Literal(null, StringType).like(Literal(null, StringType)), null)
+ checkEvaluation("abdef" like "abdef", true)
+ checkEvaluation("a_%b" like "a\\__b", true)
+ checkEvaluation("addb" like "a_%b", true)
+ checkEvaluation("addb" like "a\\__b", false)
+ checkEvaluation("addb" like "a%\\%b", false)
+ checkEvaluation("a_%b" like "a%\\%b", true)
+ checkEvaluation("addb" like "a%", true)
+ checkEvaluation("addb" like "**", false)
+ checkEvaluation("abc" like "a%", true)
+ checkEvaluation("abc" like "b%", false)
+ checkEvaluation("abc" like "bc%", false)
+ }
+
+ test("LIKE Non-literal Regular Expression") {
+ val regEx = 'a.string.at(0)
+ checkEvaluation("abcd" like regEx, null, new GenericRow(Array[Any](null)))
+ checkEvaluation("abdef" like regEx, true, new GenericRow(Array[Any]("abdef")))
+ checkEvaluation("a_%b" like regEx, true, new GenericRow(Array[Any]("a\\__b")))
+ checkEvaluation("addb" like regEx, true, new GenericRow(Array[Any]("a_%b")))
+ checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("a\\__b")))
+ checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("a%\\%b")))
+ checkEvaluation("a_%b" like regEx, true, new GenericRow(Array[Any]("a%\\%b")))
+ checkEvaluation("addb" like regEx, true, new GenericRow(Array[Any]("a%")))
+ checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("**")))
+ checkEvaluation("abc" like regEx, true, new GenericRow(Array[Any]("a%")))
+ checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("b%")))
+ checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("bc%")))
+ }
+
+ test("RLIKE literal Regular Expression") {
+ checkEvaluation("abdef" rlike "abdef", true)
+ checkEvaluation("abbbbc" rlike "a.*c", true)
+
+ checkEvaluation("fofo" rlike "^fo", true)
+ checkEvaluation("fo\no" rlike "^fo\no$", true)
+ checkEvaluation("Bn" rlike "^Ba*n", true)
+ checkEvaluation("afofo" rlike "fo", true)
+ checkEvaluation("afofo" rlike "^fo", false)
+ checkEvaluation("Baan" rlike "^Ba?n", false)
+ checkEvaluation("axe" rlike "pi|apa", false)
+ checkEvaluation("pip" rlike "^(pi)*$", false)
+
+ checkEvaluation("abc" rlike "^ab", true)
+ checkEvaluation("abc" rlike "^bc", false)
+ checkEvaluation("abc" rlike "^ab", true)
+ checkEvaluation("abc" rlike "^bc", false)
+
+ intercept[java.util.regex.PatternSyntaxException] {
+ evaluate("abbbbc" rlike "**")
+ }
+ }
+
+ test("RLIKE Non-literal Regular Expression") {
+ val regEx = 'a.string.at(0)
+ checkEvaluation("abdef" rlike regEx, true, new GenericRow(Array[Any]("abdef")))
+ checkEvaluation("abbbbc" rlike regEx, true, new GenericRow(Array[Any]("a.*c")))
+ checkEvaluation("fofo" rlike regEx, true, new GenericRow(Array[Any]("^fo")))
+ checkEvaluation("fo\no" rlike regEx, true, new GenericRow(Array[Any]("^fo\no$")))
+ checkEvaluation("Bn" rlike regEx, true, new GenericRow(Array[Any]("^Ba*n")))
+
+ intercept[java.util.regex.PatternSyntaxException] {
+ evaluate("abbbbc" rlike regEx, new GenericRow(Array[Any]("**")))
+ }
+ }
+
+ test("data type casting") {
+
+ val sts = "1970-01-01 00:00:01.0"
+ val ts = Timestamp.valueOf(sts)
+
+ checkEvaluation("abdef" cast StringType, "abdef")
+ checkEvaluation("abdef" cast DecimalType, null)
+ checkEvaluation("abdef" cast TimestampType, null)
+ checkEvaluation("12.65" cast DecimalType, BigDecimal(12.65))
+
+ checkEvaluation(Literal(1) cast LongType, 1)
+ checkEvaluation(Cast(Literal(1) cast TimestampType, LongType), 1)
+ checkEvaluation(Cast(Literal(BigDecimal(1)) cast TimestampType, DecimalType), 1)
+ checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble)
+
+ checkEvaluation(Cast(Literal(sts) cast TimestampType, StringType), sts)
+ checkEvaluation(Cast(Literal(ts) cast StringType, TimestampType), ts)
+
+ checkEvaluation(Cast("abdef" cast BinaryType, StringType), "abdef")
+
+ checkEvaluation(Cast(Cast(Cast(Cast(
+ Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), 5)
+ checkEvaluation(Cast(Cast(Cast(Cast(
+ Cast("5" cast ByteType, TimestampType), DecimalType), LongType), StringType), ShortType), 5)
+ checkEvaluation(Cast(Cast(Cast(Cast(
+ Cast("5" cast TimestampType, ByteType), DecimalType), LongType), StringType), ShortType), null)
+ checkEvaluation(Cast(Cast(Cast(Cast(
+ Cast("5" cast DecimalType, ByteType), TimestampType), LongType), StringType), ShortType), 5)
+ checkEvaluation(Literal(true) cast IntegerType, 1)
+ checkEvaluation(Literal(false) cast IntegerType, 0)
+ checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1)
+ checkEvaluation(Cast(Literal(0) cast BooleanType, IntegerType), 0)
+ checkEvaluation("23" cast DoubleType, 23)
+ checkEvaluation("23" cast IntegerType, 23)
+ checkEvaluation("23" cast FloatType, 23)
+ checkEvaluation("23" cast DecimalType, 23)
+ checkEvaluation("23" cast ByteType, 23)
+ checkEvaluation("23" cast ShortType, 23)
+ checkEvaluation("2012-12-11" cast DoubleType, null)
+ checkEvaluation(Literal(123) cast IntegerType, 123)
+
+ intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)}
+ }
+
+ test("timestamp") {
+ val ts1 = new Timestamp(12)
+ val ts2 = new Timestamp(123)
+ checkEvaluation(Literal("ab") < Literal("abc"), true)
+ checkEvaluation(Literal(ts1) < Literal(ts2), true)
+ }
}
+
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index e367edfb1f562..85580ed6b822f 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -30,6 +30,17 @@
jar
Spark Project SQL
http://spark.apache.org/
+
+
+ yarn-alpha
+
+
+ org.apache.avro
+ avro
+
+
+
+
diff --git a/sql/core/src/main/scala/org/apache/spark/rdd/PartitionLocalRDDFunctions.scala b/sql/core/src/main/scala/org/apache/spark/rdd/PartitionLocalRDDFunctions.scala
deleted file mode 100644
index f1230e7526ab1..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/rdd/PartitionLocalRDDFunctions.scala
+++ /dev/null
@@ -1,100 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import scala.language.implicitConversions
-
-import scala.reflect._
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.{Aggregator, InterruptibleIterator, Logging}
-import org.apache.spark.util.collection.AppendOnlyMap
-
-/* Implicit conversions */
-import org.apache.spark.SparkContext._
-
-/**
- * Extra functions on RDDs that perform only local operations. These can be used when data has
- * already been partitioned correctly.
- */
-private[spark] class PartitionLocalRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
- extends Logging
- with Serializable {
-
- /**
- * Cogroup corresponding partitions of `this` and `other`. These two RDDs should have
- * the same number of partitions. Partitions of these two RDDs are cogrouped
- * according to the indexes of partitions. If we have two RDDs and
- * each of them has n partitions, we will cogroup the partition i from `this`
- * with the partition i from `other`.
- * This function will not introduce a shuffling operation.
- */
- def cogroupLocally[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = {
- val cg = self.zipPartitions(other)((iter1:Iterator[(K, V)], iter2:Iterator[(K, W)]) => {
- val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]]
-
- val update: (Boolean, Seq[ArrayBuffer[Any]]) => Seq[ArrayBuffer[Any]] = (hadVal, oldVal) => {
- if (hadVal) oldVal else Array.fill(2)(new ArrayBuffer[Any])
- }
-
- val getSeq = (k: K) => {
- map.changeValue(k, update)
- }
-
- iter1.foreach { kv => getSeq(kv._1)(0) += kv._2 }
- iter2.foreach { kv => getSeq(kv._1)(1) += kv._2 }
-
- map.iterator
- }).mapValues { case Seq(vs, ws) => (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]])}
-
- cg
- }
-
- /**
- * Group the values for each key within a partition of the RDD into a single sequence.
- * This function will not introduce a shuffling operation.
- */
- def groupByKeyLocally(): RDD[(K, Seq[V])] = {
- def createCombiner(v: V) = ArrayBuffer(v)
- def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
- val aggregator = new Aggregator[K, V, ArrayBuffer[V]](createCombiner, mergeValue, _ ++ _)
- val bufs = self.mapPartitionsWithContext((context, iter) => {
- new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
- }, preservesPartitioning = true)
- bufs.asInstanceOf[RDD[(K, Seq[V])]]
- }
-
- /**
- * Join corresponding partitions of `this` and `other`.
- * If we have two RDDs and each of them has n partitions,
- * we will join the partition i from `this` with the partition i from `other`.
- * This function will not introduce a shuffling operation.
- */
- def joinLocally[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = {
- cogroupLocally(other).flatMapValues {
- case (vs, ws) => for (v <- vs.iterator; w <- ws.iterator) yield (v, w)
- }
- }
-}
-
-private[spark] object PartitionLocalRDDFunctions {
- implicit def rddToPartitionLocalRDDFunctions[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]) =
- new PartitionLocalRDDFunctions(rdd)
-}
-
-
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index cf3c06acce5b0..3193787680d16 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -26,8 +26,9 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.dsl
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Subquery, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.columnar.InMemoryColumnarTableScan
import org.apache.spark.sql.execution._
/**
@@ -79,12 +80,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd)))
/**
- * Loads a parequet file, returning the result as a [[SchemaRDD]].
+ * Loads a Parquet file, returning the result as a [[SchemaRDD]].
*
* @group userf
*/
def parquetFile(path: String): SchemaRDD =
- new SchemaRDD(this, parquet.ParquetRelation("ParquetFile", path))
+ new SchemaRDD(this, parquet.ParquetRelation(path))
/**
@@ -111,13 +112,42 @@ class SQLContext(@transient val sparkContext: SparkContext)
result
}
+ /** Returns the specified table as a SchemaRDD */
+ def table(tableName: String): SchemaRDD =
+ new SchemaRDD(this, catalog.lookupRelation(None, tableName))
+
+ /** Caches the specified table in-memory. */
+ def cacheTable(tableName: String): Unit = {
+ val currentTable = catalog.lookupRelation(None, tableName)
+ val asInMemoryRelation =
+ InMemoryColumnarTableScan(currentTable.output, executePlan(currentTable).executedPlan)
+
+ catalog.registerTable(None, tableName, SparkLogicalPlan(asInMemoryRelation))
+ }
+
+ /** Removes the specified table from the in-memory cache. */
+ def uncacheTable(tableName: String): Unit = {
+ EliminateAnalysisOperators(catalog.lookupRelation(None, tableName)) match {
+ // This is kind of a hack to make sure that if this was just an RDD registered as a table,
+ // we reregister the RDD as a table.
+ case SparkLogicalPlan(inMem @ InMemoryColumnarTableScan(_, e: ExistingRdd)) =>
+ inMem.cachedColumnBuffers.unpersist()
+ catalog.unregisterTable(None, tableName)
+ catalog.registerTable(None, tableName, SparkLogicalPlan(e))
+ case SparkLogicalPlan(inMem: InMemoryColumnarTableScan) =>
+ inMem.cachedColumnBuffers.unpersist()
+ catalog.unregisterTable(None, tableName)
+ case plan => throw new IllegalArgumentException(s"Table $tableName is not cached: $plan")
+ }
+ }
+
protected[sql] class SparkPlanner extends SparkStrategies {
val sparkContext = self.sparkContext
val strategies: Seq[Strategy] =
- TopK ::
+ TakeOrdered ::
PartialAggregation ::
- SparkEquiInnerJoin ::
+ HashJoin ::
ParquetOperations ::
BasicOperators ::
CartesianProduct ::
@@ -194,6 +224,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected def stringOrError[A](f: => A): String =
try f.toString catch { case e: Throwable => e.toString }
+ def simpleString: String = stringOrError(executedPlan)
+
override def toString: String =
s"""== Logical Plan ==
|${stringOrError(analyzed)}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 770cabcb31d13..fc95781448569 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -17,13 +17,13 @@
package org.apache.spark.sql
+import org.apache.spark.{Dependency, OneToOneDependency, Partition, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
import org.apache.spark.sql.catalyst.types.BooleanType
-import org.apache.spark.{Dependency, OneToOneDependency, Partition, TaskContext}
/**
* ALPHA COMPONENT
@@ -92,23 +92,10 @@ import org.apache.spark.{Dependency, OneToOneDependency, Partition, TaskContext}
*/
class SchemaRDD(
@transient val sqlContext: SQLContext,
- @transient val logicalPlan: LogicalPlan)
- extends RDD[Row](sqlContext.sparkContext, Nil) {
+ @transient protected[spark] val logicalPlan: LogicalPlan)
+ extends RDD[Row](sqlContext.sparkContext, Nil) with SchemaRDDLike {
- /**
- * A lazily computed query execution workflow. All other RDD operations are passed
- * through to the RDD that is produced by this workflow.
- *
- * We want this to be lazy because invoking the whole query optimization pipeline can be
- * expensive.
- */
- @transient
- protected[spark] lazy val queryExecution = sqlContext.executePlan(logicalPlan)
-
- override def toString =
- s"""${super.toString}
- |== Query Plan ==
- |${queryExecution.executedPlan}""".stripMargin.trim
+ def baseSchemaRDD = this
// =========================================================================================
// RDD functions: Copy the interal row representation so we present immutable data to users.
@@ -161,17 +148,17 @@ class SchemaRDD(
*
* @param otherPlan the [[SchemaRDD]] that should be joined with this one.
* @param joinType One of `Inner`, `LeftOuter`, `RightOuter`, or `FullOuter`. Defaults to `Inner.`
- * @param condition An optional condition for the join operation. This is equivilent to the `ON`
- * clause in standard SQL. In the case of `Inner` joins, specifying a
- * `condition` is equivilent to adding `where` clauses after the `join`.
+ * @param on An optional condition for the join operation. This is equivilent to the `ON`
+ * clause in standard SQL. In the case of `Inner` joins, specifying a
+ * `condition` is equivilent to adding `where` clauses after the `join`.
*
* @group Query
*/
def join(
otherPlan: SchemaRDD,
joinType: JoinType = Inner,
- condition: Option[Expression] = None): SchemaRDD =
- new SchemaRDD(sqlContext, Join(logicalPlan, otherPlan.logicalPlan, joinType, condition))
+ on: Option[Expression] = None): SchemaRDD =
+ new SchemaRDD(sqlContext, Join(logicalPlan, otherPlan.logicalPlan, joinType, on))
/**
* Sorts the results by the given expressions.
@@ -208,14 +195,14 @@ class SchemaRDD(
* with the same name, for example, when peforming self-joins.
*
* {{{
- * val x = schemaRDD.where('a === 1).subquery('x)
- * val y = schemaRDD.where('a === 2).subquery('y)
+ * val x = schemaRDD.where('a === 1).as('x)
+ * val y = schemaRDD.where('a === 2).as('y)
* x.join(y).where("x.a".attr === "y.a".attr),
* }}}
*
* @group Query
*/
- def subquery(alias: Symbol) =
+ def as(alias: Symbol) =
new SchemaRDD(sqlContext, Subquery(alias.name, logicalPlan))
/**
@@ -312,31 +299,12 @@ class SchemaRDD(
sqlContext,
InsertIntoTable(UnresolvedRelation(None, tableName), Map.empty, logicalPlan, overwrite))
- /**
- * Saves the contents of this `SchemaRDD` as a parquet file, preserving the schema. Files that
- * are written out using this method can be read back in as a SchemaRDD using the ``function
- *
- * @group schema
- */
- def saveAsParquetFile(path: String): Unit = {
- sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd
- }
-
- /**
- * Registers this RDD as a temporary table using the given name. The lifetime of this temporary
- * table is tied to the [[SQLContext]] that was used to create this SchemaRDD.
- *
- * @group schema
- */
- def registerAsTable(tableName: String): Unit = {
- sqlContext.registerRDDAsTable(this, tableName)
- }
-
/**
* Returns this RDD as a SchemaRDD.
* @group schema
*/
def toSchemaRDD = this
+ /** FOR INTERNAL USE ONLY */
def analyze = sqlContext.analyzer(logicalPlan)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
new file mode 100644
index 0000000000000..3dd9897c0d3b8
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
@@ -0,0 +1,65 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.plans.logical._
+
+/**
+ * Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java)
+ */
+trait SchemaRDDLike {
+ @transient val sqlContext: SQLContext
+ @transient protected[spark] val logicalPlan: LogicalPlan
+
+ private[sql] def baseSchemaRDD: SchemaRDD
+
+ /**
+ * A lazily computed query execution workflow. All other RDD operations are passed
+ * through to the RDD that is produced by this workflow.
+ *
+ * We want this to be lazy because invoking the whole query optimization pipeline can be
+ * expensive.
+ */
+ @transient
+ protected[spark] lazy val queryExecution = sqlContext.executePlan(logicalPlan)
+
+ override def toString =
+ s"""${super.toString}
+ |== Query Plan ==
+ |${queryExecution.simpleString}""".stripMargin.trim
+
+ /**
+ * Saves the contents of this `SchemaRDD` as a parquet file, preserving the schema. Files that
+ * are written out using this method can be read back in as a SchemaRDD using the ``function
+ *
+ * @group schema
+ */
+ def saveAsParquetFile(path: String): Unit = {
+ sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd
+ }
+
+ /**
+ * Registers this RDD as a temporary table using the given name. The lifetime of this temporary
+ * table is tied to the [[SQLContext]] that was used to create this SchemaRDD.
+ *
+ * @group schema
+ */
+ def registerAsTable(tableName: String): Unit = {
+ sqlContext.registerRDDAsTable(baseSchemaRDD, tableName)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
new file mode 100644
index 0000000000000..573345e42c43c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -0,0 +1,100 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.api.java
+
+import java.beans.{Introspector, PropertyDescriptor}
+
+import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow}
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.parquet.ParquetRelation
+import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan}
+
+/**
+ * The entry point for executing Spark SQL queries from a Java program.
+ */
+class JavaSQLContext(sparkContext: JavaSparkContext) {
+
+ val sqlContext = new SQLContext(sparkContext.sc)
+
+ /**
+ * Executes a query expressed in SQL, returning the result as a JavaSchemaRDD
+ */
+ def sql(sqlQuery: String): JavaSchemaRDD = {
+ val result = new JavaSchemaRDD(sqlContext, sqlContext.parseSql(sqlQuery))
+ // We force query optimization to happen right away instead of letting it happen lazily like
+ // when using the query DSL. This is so DDL commands behave as expected. This is only
+ // generates the RDD lineage for DML queries, but do not perform any execution.
+ result.queryExecution.toRdd
+ result
+ }
+
+ /**
+ * Applies a schema to an RDD of Java Beans.
+ */
+ def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): JavaSchemaRDD = {
+ // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
+ val beanInfo = Introspector.getBeanInfo(beanClass)
+
+ val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
+ val schema = fields.map { property =>
+ val dataType = property.getPropertyType match {
+ case c: Class[_] if c == classOf[java.lang.String] => StringType
+ case c: Class[_] if c == java.lang.Short.TYPE => ShortType
+ case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType
+ case c: Class[_] if c == java.lang.Long.TYPE => LongType
+ case c: Class[_] if c == java.lang.Double.TYPE => DoubleType
+ case c: Class[_] if c == java.lang.Byte.TYPE => ByteType
+ case c: Class[_] if c == java.lang.Float.TYPE => FloatType
+ case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType
+ }
+
+ AttributeReference(property.getName, dataType, true)()
+ }
+
+ val className = beanClass.getCanonicalName
+ val rowRdd = rdd.rdd.mapPartitions { iter =>
+ // BeanInfo is not serializable so we must rediscover it remotely for each partition.
+ val localBeanInfo = Introspector.getBeanInfo(Class.forName(className))
+ val extractors =
+ localBeanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod)
+
+ iter.map { row =>
+ new GenericRow(extractors.map(e => e.invoke(row)).toArray[Any]): ScalaRow
+ }
+ }
+ new JavaSchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(schema, rowRdd)))
+ }
+
+
+ /**
+ * Loads a parquet file, returning the result as a [[JavaSchemaRDD]].
+ */
+ def parquetFile(path: String): JavaSchemaRDD =
+ new JavaSchemaRDD(sqlContext, ParquetRelation(path))
+
+
+ /**
+ * Registers the given RDD as a temporary table in the catalog. Temporary tables exist only
+ * during the lifetime of this instance of SQLContext.
+ */
+ def registerRDDAsTable(rdd: JavaSchemaRDD, tableName: String): Unit = {
+ sqlContext.registerRDDAsTable(rdd.baseSchemaRDD, tableName)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
new file mode 100644
index 0000000000000..d43d672938f51
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java
+
+import org.apache.spark.api.java.{JavaRDDLike, JavaRDD}
+import org.apache.spark.sql.{SQLContext, SchemaRDD, SchemaRDDLike}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.rdd.RDD
+
+/**
+ * An RDD of [[Row]] objects that is returned as the result of a Spark SQL query. In addition to
+ * standard RDD operations, a JavaSchemaRDD can also be registered as a table in the JavaSQLContext
+ * that was used to create. Registering a JavaSchemaRDD allows its contents to be queried in
+ * future SQL statement.
+ *
+ * @groupname schema SchemaRDD Functions
+ * @groupprio schema -1
+ * @groupname Ungrouped Base RDD Functions
+ */
+class JavaSchemaRDD(
+ @transient val sqlContext: SQLContext,
+ @transient protected[spark] val logicalPlan: LogicalPlan)
+ extends JavaRDDLike[Row, JavaRDD[Row]]
+ with SchemaRDDLike {
+
+ private[sql] val baseSchemaRDD = new SchemaRDD(sqlContext, logicalPlan)
+
+ override val classTag = scala.reflect.classTag[Row]
+
+ override def wrapRDD(rdd: RDD[Row]): JavaRDD[Row] = JavaRDD.fromRDD(rdd)
+
+ val rdd = baseSchemaRDD.map(new Row(_))
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
new file mode 100644
index 0000000000000..362fe769581d7
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java
+
+import org.apache.spark.sql.catalyst.expressions.{Row => ScalaRow}
+
+/**
+ * A result row from a SparkSQL query.
+ */
+class Row(row: ScalaRow) extends Serializable {
+
+ /** Returns the number of columns present in this Row. */
+ def length: Int = row.length
+
+ /** Returns the value of column `i`. */
+ def get(i: Int): Any =
+ row(i)
+
+ /** Returns true if value at column `i` is NULL. */
+ def isNullAt(i: Int) = get(i) == null
+
+ /**
+ * Returns the value of column `i` as an int. This function will throw an exception if the value
+ * is at `i` is not an integer, or if it is null.
+ */
+ def getInt(i: Int): Int =
+ row.getInt(i)
+
+ /**
+ * Returns the value of column `i` as a long. This function will throw an exception if the value
+ * is at `i` is not a long, or if it is null.
+ */
+ def getLong(i: Int): Long =
+ row.getLong(i)
+
+ /**
+ * Returns the value of column `i` as a double. This function will throw an exception if the
+ * value is at `i` is not a double, or if it is null.
+ */
+ def getDouble(i: Int): Double =
+ row.getDouble(i)
+
+ /**
+ * Returns the value of column `i` as a bool. This function will throw an exception if the value
+ * is at `i` is not a boolean, or if it is null.
+ */
+ def getBoolean(i: Int): Boolean =
+ row.getBoolean(i)
+
+ /**
+ * Returns the value of column `i` as a short. This function will throw an exception if the value
+ * is at `i` is not a short, or if it is null.
+ */
+ def getShort(i: Int): Short =
+ row.getShort(i)
+
+ /**
+ * Returns the value of column `i` as a byte. This function will throw an exception if the value
+ * is at `i` is not a byte, or if it is null.
+ */
+ def getByte(i: Int): Byte =
+ row.getByte(i)
+
+ /**
+ * Returns the value of column `i` as a float. This function will throw an exception if the value
+ * is at `i` is not a float, or if it is null.
+ */
+ def getFloat(i: Int): Float =
+ row.getFloat(i)
+
+ /**
+ * Returns the value of column `i` as a String. This function will throw an exception if the
+ * value is at `i` is not a String.
+ */
+ def getString(i: Int): String =
+ row.getString(i)
+}
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
index e0c98ecdf8f22..ffd4894b5213d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
@@ -21,7 +21,7 @@ import java.nio.{ByteOrder, ByteBuffer}
import org.apache.spark.sql.catalyst.types.{BinaryType, NativeType, DataType}
import org.apache.spark.sql.catalyst.expressions.MutableRow
-import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor
/**
* An `Iterator` like trait used to extract values from columnar byte buffer. When a value is
@@ -41,121 +41,66 @@ private[sql] trait ColumnAccessor {
protected def underlyingBuffer: ByteBuffer
}
-private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType](buffer: ByteBuffer)
+private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType](
+ protected val buffer: ByteBuffer,
+ protected val columnType: ColumnType[T, JvmType])
extends ColumnAccessor {
protected def initialize() {}
- def columnType: ColumnType[T, JvmType]
-
def hasNext = buffer.hasRemaining
def extractTo(row: MutableRow, ordinal: Int) {
- doExtractTo(row, ordinal)
+ columnType.setField(row, ordinal, extractSingle(buffer))
}
- protected def doExtractTo(row: MutableRow, ordinal: Int)
+ def extractSingle(buffer: ByteBuffer): JvmType = columnType.extract(buffer)
protected def underlyingBuffer = buffer
}
private[sql] abstract class NativeColumnAccessor[T <: NativeType](
- buffer: ByteBuffer,
- val columnType: NativeColumnType[T])
- extends BasicColumnAccessor[T, T#JvmType](buffer)
+ override protected val buffer: ByteBuffer,
+ override protected val columnType: NativeColumnType[T])
+ extends BasicColumnAccessor(buffer, columnType)
with NullableColumnAccessor
+ with CompressibleColumnAccessor[T]
private[sql] class BooleanColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, BOOLEAN) {
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row.setBoolean(ordinal, columnType.extract(buffer))
- }
-}
+ extends NativeColumnAccessor(buffer, BOOLEAN)
private[sql] class IntColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, INT) {
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row.setInt(ordinal, columnType.extract(buffer))
- }
-}
+ extends NativeColumnAccessor(buffer, INT)
private[sql] class ShortColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, SHORT) {
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row.setShort(ordinal, columnType.extract(buffer))
- }
-}
+ extends NativeColumnAccessor(buffer, SHORT)
private[sql] class LongColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, LONG) {
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row.setLong(ordinal, columnType.extract(buffer))
- }
-}
+ extends NativeColumnAccessor(buffer, LONG)
private[sql] class ByteColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, BYTE) {
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row.setByte(ordinal, columnType.extract(buffer))
- }
-}
+ extends NativeColumnAccessor(buffer, BYTE)
private[sql] class DoubleColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, DOUBLE) {
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row.setDouble(ordinal, columnType.extract(buffer))
- }
-}
+ extends NativeColumnAccessor(buffer, DOUBLE)
private[sql] class FloatColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, FLOAT) {
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row.setFloat(ordinal, columnType.extract(buffer))
- }
-}
+ extends NativeColumnAccessor(buffer, FLOAT)
private[sql] class StringColumnAccessor(buffer: ByteBuffer)
- extends NativeColumnAccessor(buffer, STRING) {
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row.setString(ordinal, columnType.extract(buffer))
- }
-}
+ extends NativeColumnAccessor(buffer, STRING)
private[sql] class BinaryColumnAccessor(buffer: ByteBuffer)
- extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer)
- with NullableColumnAccessor {
-
- def columnType = BINARY
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- row(ordinal) = columnType.extract(buffer)
- }
-}
+ extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer, BINARY)
+ with NullableColumnAccessor
private[sql] class GenericColumnAccessor(buffer: ByteBuffer)
- extends BasicColumnAccessor[DataType, Array[Byte]](buffer)
- with NullableColumnAccessor {
-
- def columnType = GENERIC
-
- override protected def doExtractTo(row: MutableRow, ordinal: Int) {
- val serialized = columnType.extract(buffer)
- row(ordinal) = SparkSqlSerializer.deserialize[Any](serialized)
- }
-}
+ extends BasicColumnAccessor[DataType, Array[Byte]](buffer, GENERIC)
+ with NullableColumnAccessor
private[sql] object ColumnAccessor {
- def apply(b: ByteBuffer): ColumnAccessor = {
- // The first 4 bytes in the buffer indicates the column type.
- val buffer = b.duplicate().order(ByteOrder.nativeOrder())
+ def apply(buffer: ByteBuffer): ColumnAccessor = {
+ // The first 4 bytes in the buffer indicate the column type.
val columnTypeId = buffer.getInt()
columnTypeId match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
index 3e622adfd3d6a..048ee66bff44b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
@@ -22,7 +22,7 @@ import java.nio.{ByteBuffer, ByteOrder}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.columnar.ColumnBuilder._
-import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.columnar.compression.{AllCompressionSchemes, CompressibleColumnBuilder}
private[sql] trait ColumnBuilder {
/**
@@ -30,37 +30,44 @@ private[sql] trait ColumnBuilder {
*/
def initialize(initialSize: Int, columnName: String = "")
+ /**
+ * Appends `row(ordinal)` to the column builder.
+ */
def appendFrom(row: Row, ordinal: Int)
+ /**
+ * Column statistics information
+ */
+ def columnStats: ColumnStats[_, _]
+
+ /**
+ * Returns the final columnar byte buffer.
+ */
def build(): ByteBuffer
}
-private[sql] abstract class BasicColumnBuilder[T <: DataType, JvmType] extends ColumnBuilder {
+private[sql] class BasicColumnBuilder[T <: DataType, JvmType](
+ val columnStats: ColumnStats[T, JvmType],
+ val columnType: ColumnType[T, JvmType])
+ extends ColumnBuilder {
- private var columnName: String = _
- protected var buffer: ByteBuffer = _
+ protected var columnName: String = _
- def columnType: ColumnType[T, JvmType]
+ protected var buffer: ByteBuffer = _
override def initialize(initialSize: Int, columnName: String = "") = {
val size = if (initialSize == 0) DEFAULT_INITIAL_BUFFER_SIZE else initialSize
this.columnName = columnName
- buffer = ByteBuffer.allocate(4 + 4 + size * columnType.defaultSize)
+
+ // Reserves 4 bytes for column type ID
+ buffer = ByteBuffer.allocate(4 + size * columnType.defaultSize)
buffer.order(ByteOrder.nativeOrder()).putInt(columnType.typeId)
}
- // Have to give a concrete implementation to make mixin possible
override def appendFrom(row: Row, ordinal: Int) {
- doAppendFrom(row, ordinal)
- }
-
- // Concrete `ColumnBuilder`s can override this method to append values
- protected def doAppendFrom(row: Row, ordinal: Int)
-
- // Helper method to append primitive values (to avoid boxing cost)
- protected def appendValue(v: JvmType) {
- buffer = ensureFreeSpace(buffer, columnType.actualSize(v))
- columnType.append(v, buffer)
+ val field = columnType.getField(row, ordinal)
+ buffer = ensureFreeSpace(buffer, columnType.actualSize(field))
+ columnType.append(field, buffer)
}
override def build() = {
@@ -69,83 +76,39 @@ private[sql] abstract class BasicColumnBuilder[T <: DataType, JvmType] extends C
}
}
-private[sql] abstract class NativeColumnBuilder[T <: NativeType](
- val columnType: NativeColumnType[T])
- extends BasicColumnBuilder[T, T#JvmType]
+private[sql] abstract class ComplexColumnBuilder[T <: DataType, JvmType](
+ columnType: ColumnType[T, JvmType])
+ extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType)
with NullableColumnBuilder
-private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(BOOLEAN) {
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row.getBoolean(ordinal))
- }
-}
-
-private[sql] class IntColumnBuilder extends NativeColumnBuilder(INT) {
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row.getInt(ordinal))
- }
-}
+private[sql] abstract class NativeColumnBuilder[T <: NativeType](
+ override val columnStats: NativeColumnStats[T],
+ override val columnType: NativeColumnType[T])
+ extends BasicColumnBuilder[T, T#JvmType](columnStats, columnType)
+ with NullableColumnBuilder
+ with AllCompressionSchemes
+ with CompressibleColumnBuilder[T]
-private[sql] class ShortColumnBuilder extends NativeColumnBuilder(SHORT) {
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row.getShort(ordinal))
- }
-}
+private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN)
-private[sql] class LongColumnBuilder extends NativeColumnBuilder(LONG) {
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row.getLong(ordinal))
- }
-}
+private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT)
-private[sql] class ByteColumnBuilder extends NativeColumnBuilder(BYTE) {
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row.getByte(ordinal))
- }
-}
+private[sql] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT)
-private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(DOUBLE) {
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row.getDouble(ordinal))
- }
-}
+private[sql] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG)
-private[sql] class FloatColumnBuilder extends NativeColumnBuilder(FLOAT) {
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row.getFloat(ordinal))
- }
-}
+private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE)
-private[sql] class StringColumnBuilder extends NativeColumnBuilder(STRING) {
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row.getString(ordinal))
- }
-}
+private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE)
-private[sql] class BinaryColumnBuilder
- extends BasicColumnBuilder[BinaryType.type, Array[Byte]]
- with NullableColumnBuilder {
+private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT)
- def columnType = BINARY
+private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING)
- override def doAppendFrom(row: Row, ordinal: Int) {
- appendValue(row(ordinal).asInstanceOf[Array[Byte]])
- }
-}
+private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(BINARY)
// TODO (lian) Add support for array, struct and map
-private[sql] class GenericColumnBuilder
- extends BasicColumnBuilder[DataType, Array[Byte]]
- with NullableColumnBuilder {
-
- def columnType = GENERIC
-
- override def doAppendFrom(row: Row, ordinal: Int) {
- val serialized = SparkSqlSerializer.serialize(row(ordinal))
- buffer = ColumnBuilder.ensureFreeSpace(buffer, columnType.actualSize(serialized))
- columnType.append(serialized, buffer)
- }
-}
+private[sql] class GenericColumnBuilder extends ComplexColumnBuilder(GENERIC)
private[sql] object ColumnBuilder {
val DEFAULT_INITIAL_BUFFER_SIZE = 10 * 1024 * 104
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
new file mode 100644
index 0000000000000..30c6bdc7912fc
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -0,0 +1,360 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.types._
+
+private[sql] sealed abstract class ColumnStats[T <: DataType, JvmType] extends Serializable {
+ /**
+ * Closed lower bound of this column.
+ */
+ def lowerBound: JvmType
+
+ /**
+ * Closed upper bound of this column.
+ */
+ def upperBound: JvmType
+
+ /**
+ * Gathers statistics information from `row(ordinal)`.
+ */
+ def gatherStats(row: Row, ordinal: Int)
+
+ /**
+ * Returns `true` if `lower <= row(ordinal) <= upper`.
+ */
+ def contains(row: Row, ordinal: Int): Boolean
+
+ /**
+ * Returns `true` if `row(ordinal) < upper` holds.
+ */
+ def isAbove(row: Row, ordinal: Int): Boolean
+
+ /**
+ * Returns `true` if `lower < row(ordinal)` holds.
+ */
+ def isBelow(row: Row, ordinal: Int): Boolean
+
+ /**
+ * Returns `true` if `row(ordinal) <= upper` holds.
+ */
+ def isAtOrAbove(row: Row, ordinal: Int): Boolean
+
+ /**
+ * Returns `true` if `lower <= row(ordinal)` holds.
+ */
+ def isAtOrBelow(row: Row, ordinal: Int): Boolean
+}
+
+private[sql] sealed abstract class NativeColumnStats[T <: NativeType]
+ extends ColumnStats[T, T#JvmType] {
+
+ type JvmType = T#JvmType
+
+ protected var (_lower, _upper) = initialBounds
+
+ def initialBounds: (JvmType, JvmType)
+
+ protected def columnType: NativeColumnType[T]
+
+ override def lowerBound: T#JvmType = _lower
+
+ override def upperBound: T#JvmType = _upper
+
+ override def isAtOrAbove(row: Row, ordinal: Int) = {
+ contains(row, ordinal) || isAbove(row, ordinal)
+ }
+
+ override def isAtOrBelow(row: Row, ordinal: Int) = {
+ contains(row, ordinal) || isBelow(row, ordinal)
+ }
+}
+
+private[sql] class NoopColumnStats[T <: DataType, JvmType] extends ColumnStats[T, JvmType] {
+ override def isAtOrBelow(row: Row, ordinal: Int) = true
+
+ override def isAtOrAbove(row: Row, ordinal: Int) = true
+
+ override def isBelow(row: Row, ordinal: Int) = true
+
+ override def isAbove(row: Row, ordinal: Int) = true
+
+ override def contains(row: Row, ordinal: Int) = true
+
+ override def gatherStats(row: Row, ordinal: Int) {}
+
+ override def upperBound = null.asInstanceOf[JvmType]
+
+ override def lowerBound = null.asInstanceOf[JvmType]
+}
+
+private[sql] abstract class BasicColumnStats[T <: NativeType](
+ protected val columnType: NativeColumnType[T])
+ extends NativeColumnStats[T]
+
+private[sql] class BooleanColumnStats extends BasicColumnStats(BOOLEAN) {
+ override def initialBounds = (true, false)
+
+ override def isBelow(row: Row, ordinal: Int) = {
+ lowerBound < columnType.getField(row, ordinal)
+ }
+
+ override def isAbove(row: Row, ordinal: Int) = {
+ columnType.getField(row, ordinal) < upperBound
+ }
+
+ override def contains(row: Row, ordinal: Int) = {
+ val field = columnType.getField(row, ordinal)
+ lowerBound <= field && field <= upperBound
+ }
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+ if (field > upperBound) _upper = field
+ if (field < lowerBound) _lower = field
+ }
+}
+
+private[sql] class ByteColumnStats extends BasicColumnStats(BYTE) {
+ override def initialBounds = (Byte.MaxValue, Byte.MinValue)
+
+ override def isBelow(row: Row, ordinal: Int) = {
+ lowerBound < columnType.getField(row, ordinal)
+ }
+
+ override def isAbove(row: Row, ordinal: Int) = {
+ columnType.getField(row, ordinal) < upperBound
+ }
+
+ override def contains(row: Row, ordinal: Int) = {
+ val field = columnType.getField(row, ordinal)
+ lowerBound <= field && field <= upperBound
+ }
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+ if (field > upperBound) _upper = field
+ if (field < lowerBound) _lower = field
+ }
+}
+
+private[sql] class ShortColumnStats extends BasicColumnStats(SHORT) {
+ override def initialBounds = (Short.MaxValue, Short.MinValue)
+
+ override def isBelow(row: Row, ordinal: Int) = {
+ lowerBound < columnType.getField(row, ordinal)
+ }
+
+ override def isAbove(row: Row, ordinal: Int) = {
+ columnType.getField(row, ordinal) < upperBound
+ }
+
+ override def contains(row: Row, ordinal: Int) = {
+ val field = columnType.getField(row, ordinal)
+ lowerBound <= field && field <= upperBound
+ }
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+ if (field > upperBound) _upper = field
+ if (field < lowerBound) _lower = field
+ }
+}
+
+private[sql] class LongColumnStats extends BasicColumnStats(LONG) {
+ override def initialBounds = (Long.MaxValue, Long.MinValue)
+
+ override def isBelow(row: Row, ordinal: Int) = {
+ lowerBound < columnType.getField(row, ordinal)
+ }
+
+ override def isAbove(row: Row, ordinal: Int) = {
+ columnType.getField(row, ordinal) < upperBound
+ }
+
+ override def contains(row: Row, ordinal: Int) = {
+ val field = columnType.getField(row, ordinal)
+ lowerBound <= field && field <= upperBound
+ }
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+ if (field > upperBound) _upper = field
+ if (field < lowerBound) _lower = field
+ }
+}
+
+private[sql] class DoubleColumnStats extends BasicColumnStats(DOUBLE) {
+ override def initialBounds = (Double.MaxValue, Double.MinValue)
+
+ override def isBelow(row: Row, ordinal: Int) = {
+ lowerBound < columnType.getField(row, ordinal)
+ }
+
+ override def isAbove(row: Row, ordinal: Int) = {
+ columnType.getField(row, ordinal) < upperBound
+ }
+
+ override def contains(row: Row, ordinal: Int) = {
+ val field = columnType.getField(row, ordinal)
+ lowerBound <= field && field <= upperBound
+ }
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+ if (field > upperBound) _upper = field
+ if (field < lowerBound) _lower = field
+ }
+}
+
+private[sql] class FloatColumnStats extends BasicColumnStats(FLOAT) {
+ override def initialBounds = (Float.MaxValue, Float.MinValue)
+
+ override def isBelow(row: Row, ordinal: Int) = {
+ lowerBound < columnType.getField(row, ordinal)
+ }
+
+ override def isAbove(row: Row, ordinal: Int) = {
+ columnType.getField(row, ordinal) < upperBound
+ }
+
+ override def contains(row: Row, ordinal: Int) = {
+ val field = columnType.getField(row, ordinal)
+ lowerBound <= field && field <= upperBound
+ }
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+ if (field > upperBound) _upper = field
+ if (field < lowerBound) _lower = field
+ }
+}
+
+private[sql] object IntColumnStats {
+ val UNINITIALIZED = 0
+ val INITIALIZED = 1
+ val ASCENDING = 2
+ val DESCENDING = 3
+ val UNORDERED = 4
+}
+
+/**
+ * Statistical information for `Int` columns. More information is collected since `Int` is
+ * frequently used. Extra information include:
+ *
+ * - Ordering state (ascending/descending/unordered), may be used to decide whether binary search
+ * is applicable when searching elements.
+ * - Maximum delta between adjacent elements, may be used to guide the `IntDelta` compression
+ * scheme.
+ *
+ * (This two kinds of information are not used anywhere yet and might be removed later.)
+ */
+private[sql] class IntColumnStats extends BasicColumnStats(INT) {
+ import IntColumnStats._
+
+ private var orderedState = UNINITIALIZED
+ private var lastValue: Int = _
+ private var _maxDelta: Int = _
+
+ def isAscending = orderedState != DESCENDING && orderedState != UNORDERED
+ def isDescending = orderedState != ASCENDING && orderedState != UNORDERED
+ def isOrdered = isAscending || isDescending
+ def maxDelta = _maxDelta
+
+ override def initialBounds = (Int.MaxValue, Int.MinValue)
+
+ override def isBelow(row: Row, ordinal: Int) = {
+ lowerBound < columnType.getField(row, ordinal)
+ }
+
+ override def isAbove(row: Row, ordinal: Int) = {
+ columnType.getField(row, ordinal) < upperBound
+ }
+
+ override def contains(row: Row, ordinal: Int) = {
+ val field = columnType.getField(row, ordinal)
+ lowerBound <= field && field <= upperBound
+ }
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+
+ if (field > upperBound) _upper = field
+ if (field < lowerBound) _lower = field
+
+ orderedState = orderedState match {
+ case UNINITIALIZED =>
+ lastValue = field
+ INITIALIZED
+
+ case INITIALIZED =>
+ // If all the integers in the column are the same, ordered state is set to Ascending.
+ // TODO (lian) Confirm whether this is the standard behaviour.
+ val nextState = if (field >= lastValue) ASCENDING else DESCENDING
+ _maxDelta = math.abs(field - lastValue)
+ lastValue = field
+ nextState
+
+ case ASCENDING if field < lastValue =>
+ UNORDERED
+
+ case DESCENDING if field > lastValue =>
+ UNORDERED
+
+ case state @ (ASCENDING | DESCENDING) =>
+ _maxDelta = _maxDelta.max(field - lastValue)
+ lastValue = field
+ state
+
+ case _ =>
+ orderedState
+ }
+ }
+}
+
+private[sql] class StringColumnStats extends BasicColumnStats(STRING) {
+ override def initialBounds = (null, null)
+
+ override def gatherStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+ if ((upperBound eq null) || field.compareTo(upperBound) > 0) _upper = field
+ if ((lowerBound eq null) || field.compareTo(lowerBound) < 0) _lower = field
+ }
+
+ override def contains(row: Row, ordinal: Int) = {
+ !(upperBound eq null) && {
+ val field = columnType.getField(row, ordinal)
+ lowerBound.compareTo(field) <= 0 && field.compareTo(upperBound) <= 0
+ }
+ }
+
+ override def isAbove(row: Row, ordinal: Int) = {
+ !(upperBound eq null) && {
+ val field = columnType.getField(row, ordinal)
+ field.compareTo(upperBound) < 0
+ }
+ }
+
+ override def isBelow(row: Row, ordinal: Int) = {
+ !(lowerBound eq null) && {
+ val field = columnType.getField(row, ordinal)
+ lowerBound.compareTo(field) < 0
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index a452b86f0cda3..5be76890afe31 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -19,7 +19,12 @@ package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.MutableRow
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.execution.SparkSqlSerializer
/**
* An abstract class that represents type of a column. Used to append/extract Java objects into/from
@@ -50,10 +55,24 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType](
*/
def actualSize(v: JvmType): Int = defaultSize
+ /**
+ * Returns `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs
+ * whenever possible.
+ */
+ def getField(row: Row, ordinal: Int): JvmType
+
+ /**
+ * Sets `row(ordinal)` to `field`. Subclasses should override this method to avoid boxing/unboxing
+ * costs whenever possible.
+ */
+ def setField(row: MutableRow, ordinal: Int, value: JvmType)
+
/**
* Creates a duplicated copy of the value.
*/
def clone(v: JvmType): JvmType = v
+
+ override def toString = getClass.getSimpleName.stripSuffix("$")
}
private[sql] abstract class NativeColumnType[T <: NativeType](
@@ -65,7 +84,7 @@ private[sql] abstract class NativeColumnType[T <: NativeType](
/**
* Scala TypeTag. Can be used to create primitive arrays and hash tables.
*/
- def scalaTag = dataType.tag
+ def scalaTag: TypeTag[dataType.JvmType] = dataType.tag
}
private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) {
@@ -76,6 +95,12 @@ private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) {
def extract(buffer: ByteBuffer) = {
buffer.getInt()
}
+
+ override def setField(row: MutableRow, ordinal: Int, value: Int) {
+ row.setInt(ordinal, value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = row.getInt(ordinal)
}
private[sql] object LONG extends NativeColumnType(LongType, 1, 8) {
@@ -86,6 +111,12 @@ private[sql] object LONG extends NativeColumnType(LongType, 1, 8) {
override def extract(buffer: ByteBuffer) = {
buffer.getLong()
}
+
+ override def setField(row: MutableRow, ordinal: Int, value: Long) {
+ row.setLong(ordinal, value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = row.getLong(ordinal)
}
private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) {
@@ -96,6 +127,12 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) {
override def extract(buffer: ByteBuffer) = {
buffer.getFloat()
}
+
+ override def setField(row: MutableRow, ordinal: Int, value: Float) {
+ row.setFloat(ordinal, value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = row.getFloat(ordinal)
}
private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) {
@@ -106,6 +143,12 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) {
override def extract(buffer: ByteBuffer) = {
buffer.getDouble()
}
+
+ override def setField(row: MutableRow, ordinal: Int, value: Double) {
+ row.setDouble(ordinal, value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = row.getDouble(ordinal)
}
private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) {
@@ -116,6 +159,12 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) {
override def extract(buffer: ByteBuffer) = {
if (buffer.get() == 1) true else false
}
+
+ override def setField(row: MutableRow, ordinal: Int, value: Boolean) {
+ row.setBoolean(ordinal, value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = row.getBoolean(ordinal)
}
private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) {
@@ -126,6 +175,12 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) {
override def extract(buffer: ByteBuffer) = {
buffer.get()
}
+
+ override def setField(row: MutableRow, ordinal: Int, value: Byte) {
+ row.setByte(ordinal, value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = row.getByte(ordinal)
}
private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) {
@@ -136,6 +191,12 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) {
override def extract(buffer: ByteBuffer) = {
buffer.getShort()
}
+
+ override def setField(row: MutableRow, ordinal: Int, value: Short) {
+ row.setShort(ordinal, value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = row.getShort(ordinal)
}
private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
@@ -152,6 +213,12 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
buffer.get(stringBytes, 0, length)
new String(stringBytes)
}
+
+ override def setField(row: MutableRow, ordinal: Int, value: String) {
+ row.setString(ordinal, value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = row.getString(ordinal)
}
private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
@@ -173,15 +240,27 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
}
}
-private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](8, 16)
+private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](8, 16) {
+ override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) {
+ row(ordinal) = value
+ }
+
+ override def getField(row: Row, ordinal: Int) = row(ordinal).asInstanceOf[Array[Byte]]
+}
// Used to process generic objects (all types other than those listed above). Objects should be
// serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized
// byte array.
-private[sql] object GENERIC extends ByteArrayColumnType[DataType](9, 16)
+private[sql] object GENERIC extends ByteArrayColumnType[DataType](9, 16) {
+ override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) {
+ row(ordinal) = SparkSqlSerializer.deserialize[Any](value)
+ }
+
+ override def getField(row: Row, ordinal: Int) = SparkSqlSerializer.serialize(row(ordinal))
+}
private[sql] object ColumnType {
- implicit def dataTypeToColumnType(dataType: DataType): ColumnType[_, _] = {
+ def apply(dataType: DataType): ColumnType[_, _] = {
dataType match {
case IntegerType => INT
case LongType => LONG
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/inMemoryColumnarOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
similarity index 93%
rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/inMemoryColumnarOperators.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index f853759e5a306..8a24733047423 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/inMemoryColumnarOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -21,9 +21,6 @@ import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Attribute}
import org.apache.spark.sql.execution.{SparkPlan, LeafNode}
import org.apache.spark.sql.Row
-/* Implicit conversions */
-import org.apache.spark.sql.columnar.ColumnType._
-
private[sql] case class InMemoryColumnarTableScan(attributes: Seq[Attribute], child: SparkPlan)
extends LeafNode {
@@ -32,8 +29,8 @@ private[sql] case class InMemoryColumnarTableScan(attributes: Seq[Attribute], ch
lazy val cachedColumnBuffers = {
val output = child.output
val cached = child.execute().mapPartitions { iterator =>
- val columnBuilders = output.map { a =>
- ColumnBuilder(a.dataType.typeId, 0, a.name)
+ val columnBuilders = output.map { attribute =>
+ ColumnBuilder(ColumnType(attribute.dataType).typeId, 0, attribute.name)
}.toArray
var row: Row = null
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala
index 2970c609b928d..7d49ab07f7a53 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala
@@ -29,7 +29,7 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor {
private var nextNullIndex: Int = _
private var pos: Int = 0
- abstract override def initialize() {
+ abstract override protected def initialize() {
nullsBuffer = underlyingBuffer.duplicate().order(ByteOrder.nativeOrder())
nullCount = nullsBuffer.getInt()
nextNullIndex = if (nullCount > 0) nullsBuffer.getInt() else -1
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
index 048d1f05c7df2..2a3b6fc1e46d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
@@ -22,10 +22,18 @@ import java.nio.{ByteBuffer, ByteOrder}
import org.apache.spark.sql.Row
/**
- * Builds a nullable column. The byte buffer of a nullable column contains:
- * - 4 bytes for the null count (number of nulls)
- * - positions for each null, in ascending order
- * - the non-null data (column data type, compression type, data...)
+ * A stackable trait used for building byte buffer for a column containing null values. Memory
+ * layout of the final byte buffer is:
+ * {{{
+ * .----------------------- Column type ID (4 bytes)
+ * | .------------------- Null count N (4 bytes)
+ * | | .--------------- Null positions (4 x N bytes, empty if null count is zero)
+ * | | | .--------- Non-null elements
+ * V V V V
+ * +---+---+-----+---------+
+ * | | | ... | ... ... |
+ * +---+---+-----+---------+
+ * }}}
*/
private[sql] trait NullableColumnBuilder extends ColumnBuilder {
private var nulls: ByteBuffer = _
@@ -59,19 +67,8 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder {
nulls.limit(nullDataLen)
nulls.rewind()
- // Column type ID is moved to the front, follows the null count, then non-null data
- //
- // +---------+
- // | 4 bytes | Column type ID
- // +---------+
- // | 4 bytes | Null count
- // +---------+
- // | ... | Null positions (if null count is not zero)
- // +---------+
- // | ... | Non-null part (without column type ID)
- // +---------+
val buffer = ByteBuffer
- .allocate(4 + nullDataLen + nonNulls.limit)
+ .allocate(4 + 4 + nullDataLen + nonNulls.remaining())
.order(ByteOrder.nativeOrder())
.putInt(typeId)
.putInt(nullCount)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala
new file mode 100644
index 0000000000000..878cb84de106f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar.compression
+
+import java.nio.ByteBuffer
+
+import org.apache.spark.sql.catalyst.types.NativeType
+import org.apache.spark.sql.columnar.{ColumnAccessor, NativeColumnAccessor}
+
+private[sql] trait CompressibleColumnAccessor[T <: NativeType] extends ColumnAccessor {
+ this: NativeColumnAccessor[T] =>
+
+ private var decoder: Decoder[T] = _
+
+ abstract override protected def initialize() = {
+ super.initialize()
+ decoder = CompressionScheme(underlyingBuffer.getInt()).decoder(buffer, columnType)
+ }
+
+ abstract override def extractSingle(buffer: ByteBuffer): T#JvmType = decoder.next()
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
new file mode 100644
index 0000000000000..3ac4b358ddf83
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar.compression
+
+import java.nio.{ByteBuffer, ByteOrder}
+
+import org.apache.spark.sql.{Logging, Row}
+import org.apache.spark.sql.catalyst.types.NativeType
+import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder}
+
+/**
+ * A stackable trait that builds optionally compressed byte buffer for a column. Memory layout of
+ * the final byte buffer is:
+ * {{{
+ * .--------------------------- Column type ID (4 bytes)
+ * | .----------------------- Null count N (4 bytes)
+ * | | .------------------- Null positions (4 x N bytes, empty if null count is zero)
+ * | | | .------------- Compression scheme ID (4 bytes)
+ * | | | | .--------- Compressed non-null elements
+ * V V V V V
+ * +---+---+-----+---+---------+
+ * | | | ... | | ... ... |
+ * +---+---+-----+---+---------+
+ * \-----------/ \-----------/
+ * header body
+ * }}}
+ */
+private[sql] trait CompressibleColumnBuilder[T <: NativeType]
+ extends ColumnBuilder with Logging {
+
+ this: NativeColumnBuilder[T] with WithCompressionSchemes =>
+
+ import CompressionScheme._
+
+ val compressionEncoders = schemes.filter(_.supports(columnType)).map(_.encoder)
+
+ protected def isWorthCompressing(encoder: Encoder) = {
+ encoder.compressionRatio < 0.8
+ }
+
+ private def gatherCompressibilityStats(row: Row, ordinal: Int) {
+ val field = columnType.getField(row, ordinal)
+
+ var i = 0
+ while (i < compressionEncoders.length) {
+ compressionEncoders(i).gatherCompressibilityStats(field, columnType)
+ i += 1
+ }
+ }
+
+ abstract override def appendFrom(row: Row, ordinal: Int) {
+ super.appendFrom(row, ordinal)
+ gatherCompressibilityStats(row, ordinal)
+ }
+
+ abstract override def build() = {
+ val rawBuffer = super.build()
+ val encoder = {
+ val candidate = compressionEncoders.minBy(_.compressionRatio)
+ if (isWorthCompressing(candidate)) candidate else PassThrough.encoder
+ }
+
+ val headerSize = columnHeaderSize(rawBuffer)
+ val compressedSize = if (encoder.compressedSize == 0) {
+ rawBuffer.limit - headerSize
+ } else {
+ encoder.compressedSize
+ }
+
+ // Reserves 4 bytes for compression scheme ID
+ val compressedBuffer = ByteBuffer
+ .allocate(headerSize + 4 + compressedSize)
+ .order(ByteOrder.nativeOrder)
+
+ copyColumnHeader(rawBuffer, compressedBuffer)
+
+ logger.info(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}")
+ encoder.compress(rawBuffer, compressedBuffer, columnType)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
new file mode 100644
index 0000000000000..d3a4ac8df926b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar.compression
+
+import java.nio.ByteBuffer
+
+import org.apache.spark.sql.catalyst.types.NativeType
+import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType}
+
+private[sql] trait Encoder {
+ def gatherCompressibilityStats[T <: NativeType](
+ value: T#JvmType,
+ columnType: ColumnType[T, T#JvmType]) {}
+
+ def compressedSize: Int
+
+ def uncompressedSize: Int
+
+ def compressionRatio: Double = {
+ if (uncompressedSize > 0) compressedSize.toDouble / uncompressedSize else 1.0
+ }
+
+ def compress[T <: NativeType](
+ from: ByteBuffer,
+ to: ByteBuffer,
+ columnType: ColumnType[T, T#JvmType]): ByteBuffer
+}
+
+private[sql] trait Decoder[T <: NativeType] extends Iterator[T#JvmType]
+
+private[sql] trait CompressionScheme {
+ def typeId: Int
+
+ def supports(columnType: ColumnType[_, _]): Boolean
+
+ def encoder: Encoder
+
+ def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T]
+}
+
+private[sql] trait WithCompressionSchemes {
+ def schemes: Seq[CompressionScheme]
+}
+
+private[sql] trait AllCompressionSchemes extends WithCompressionSchemes {
+ override val schemes: Seq[CompressionScheme] = {
+ Seq(PassThrough, RunLengthEncoding, DictionaryEncoding)
+ }
+}
+
+private[sql] object CompressionScheme {
+ def apply(typeId: Int): CompressionScheme = typeId match {
+ case PassThrough.typeId => PassThrough
+ case _ => throw new UnsupportedOperationException()
+ }
+
+ def copyColumnHeader(from: ByteBuffer, to: ByteBuffer) {
+ // Writes column type ID
+ to.putInt(from.getInt())
+
+ // Writes null count
+ val nullCount = from.getInt()
+ to.putInt(nullCount)
+
+ // Writes null positions
+ var i = 0
+ while (i < nullCount) {
+ to.putInt(from.getInt())
+ i += 1
+ }
+ }
+
+ def columnHeaderSize(columnBuffer: ByteBuffer): Int = {
+ val header = columnBuffer.duplicate()
+ val nullCount = header.getInt(4)
+ // Column type ID + null count + null positions
+ 4 + 4 + 4 * nullCount
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
new file mode 100644
index 0000000000000..dc2c153faf8ad
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
@@ -0,0 +1,288 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar.compression
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.runtimeMirror
+
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.types.NativeType
+import org.apache.spark.sql.columnar._
+
+private[sql] case object PassThrough extends CompressionScheme {
+ override val typeId = 0
+
+ override def supports(columnType: ColumnType[_, _]) = true
+
+ override def encoder = new this.Encoder
+
+ override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = {
+ new this.Decoder(buffer, columnType)
+ }
+
+ class Encoder extends compression.Encoder {
+ override def uncompressedSize = 0
+
+ override def compressedSize = 0
+
+ override def compress[T <: NativeType](
+ from: ByteBuffer,
+ to: ByteBuffer,
+ columnType: ColumnType[T, T#JvmType]) = {
+
+ // Writes compression type ID and copies raw contents
+ to.putInt(PassThrough.typeId).put(from).rewind()
+ to
+ }
+ }
+
+ class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T])
+ extends compression.Decoder[T] {
+
+ override def next() = columnType.extract(buffer)
+
+ override def hasNext = buffer.hasRemaining
+ }
+}
+
+private[sql] case object RunLengthEncoding extends CompressionScheme {
+ override def typeId = 1
+
+ override def encoder = new this.Encoder
+
+ override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = {
+ new this.Decoder(buffer, columnType)
+ }
+
+ override def supports(columnType: ColumnType[_, _]) = columnType match {
+ case INT | LONG | SHORT | BYTE | STRING | BOOLEAN => true
+ case _ => false
+ }
+
+ class Encoder extends compression.Encoder {
+ private var _uncompressedSize = 0
+ private var _compressedSize = 0
+
+ // Using `MutableRow` to store the last value to avoid boxing/unboxing cost.
+ private val lastValue = new GenericMutableRow(1)
+ private var lastRun = 0
+
+ override def uncompressedSize = _uncompressedSize
+
+ override def compressedSize = _compressedSize
+
+ override def gatherCompressibilityStats[T <: NativeType](
+ value: T#JvmType,
+ columnType: ColumnType[T, T#JvmType]) {
+
+ val actualSize = columnType.actualSize(value)
+ _uncompressedSize += actualSize
+
+ if (lastValue.isNullAt(0)) {
+ columnType.setField(lastValue, 0, value)
+ lastRun = 1
+ _compressedSize += actualSize + 4
+ } else {
+ if (columnType.getField(lastValue, 0) == value) {
+ lastRun += 1
+ } else {
+ _compressedSize += actualSize + 4
+ columnType.setField(lastValue, 0, value)
+ lastRun = 1
+ }
+ }
+ }
+
+ override def compress[T <: NativeType](
+ from: ByteBuffer,
+ to: ByteBuffer,
+ columnType: ColumnType[T, T#JvmType]) = {
+
+ to.putInt(RunLengthEncoding.typeId)
+
+ if (from.hasRemaining) {
+ var currentValue = columnType.extract(from)
+ var currentRun = 1
+
+ while (from.hasRemaining) {
+ val value = columnType.extract(from)
+
+ if (value == currentValue) {
+ currentRun += 1
+ } else {
+ // Writes current run
+ columnType.append(currentValue, to)
+ to.putInt(currentRun)
+
+ // Resets current run
+ currentValue = value
+ currentRun = 1
+ }
+ }
+
+ // Writes the last run
+ columnType.append(currentValue, to)
+ to.putInt(currentRun)
+ }
+
+ to.rewind()
+ to
+ }
+ }
+
+ class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T])
+ extends compression.Decoder[T] {
+
+ private var run = 0
+ private var valueCount = 0
+ private var currentValue: T#JvmType = _
+
+ override def next() = {
+ if (valueCount == run) {
+ currentValue = columnType.extract(buffer)
+ run = buffer.getInt()
+ valueCount = 1
+ } else {
+ valueCount += 1
+ }
+
+ currentValue
+ }
+
+ override def hasNext = buffer.hasRemaining
+ }
+}
+
+private[sql] case object DictionaryEncoding extends CompressionScheme {
+ override def typeId: Int = 2
+
+ // 32K unique values allowed
+ private val MAX_DICT_SIZE = Short.MaxValue - 1
+
+ override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = {
+ new this.Decoder[T](buffer, columnType)
+ }
+
+ override def encoder = new this.Encoder
+
+ override def supports(columnType: ColumnType[_, _]) = columnType match {
+ case INT | LONG | STRING => true
+ case _ => false
+ }
+
+ class Encoder extends compression.Encoder{
+ // Size of the input, uncompressed, in bytes. Note that we only count until the dictionary
+ // overflows.
+ private var _uncompressedSize = 0
+
+ // If the number of distinct elements is too large, we discard the use of dictionary encoding
+ // and set the overflow flag to true.
+ private var overflow = false
+
+ // Total number of elements.
+ private var count = 0
+
+ // The reverse mapping of _dictionary, i.e. mapping encoded integer to the value itself.
+ private var values = new mutable.ArrayBuffer[Any](1024)
+
+ // The dictionary that maps a value to the encoded short integer.
+ private val dictionary = mutable.HashMap.empty[Any, Short]
+
+ // Size of the serialized dictionary in bytes. Initialized to 4 since we need at least an `Int`
+ // to store dictionary element count.
+ private var dictionarySize = 4
+
+ override def gatherCompressibilityStats[T <: NativeType](
+ value: T#JvmType,
+ columnType: ColumnType[T, T#JvmType]) {
+
+ if (!overflow) {
+ val actualSize = columnType.actualSize(value)
+ count += 1
+ _uncompressedSize += actualSize
+
+ if (!dictionary.contains(value)) {
+ if (dictionary.size < MAX_DICT_SIZE) {
+ val clone = columnType.clone(value)
+ values += clone
+ dictionarySize += actualSize
+ dictionary(clone) = dictionary.size.toShort
+ } else {
+ overflow = true
+ values.clear()
+ dictionary.clear()
+ }
+ }
+ }
+ }
+
+ override def compress[T <: NativeType](
+ from: ByteBuffer,
+ to: ByteBuffer,
+ columnType: ColumnType[T, T#JvmType]) = {
+
+ if (overflow) {
+ throw new IllegalStateException(
+ "Dictionary encoding should not be used because of dictionary overflow.")
+ }
+
+ to.putInt(DictionaryEncoding.typeId)
+ .putInt(dictionary.size)
+
+ var i = 0
+ while (i < values.length) {
+ columnType.append(values(i).asInstanceOf[T#JvmType], to)
+ i += 1
+ }
+
+ while (from.hasRemaining) {
+ to.putShort(dictionary(columnType.extract(from)))
+ }
+
+ to.rewind()
+ to
+ }
+
+ override def uncompressedSize = _uncompressedSize
+
+ override def compressedSize = if (overflow) Int.MaxValue else dictionarySize + count * 2
+ }
+
+ class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T])
+ extends compression.Decoder[T] {
+
+ private val dictionary = {
+ // TODO Can we clean up this mess? Maybe move this to `DataType`?
+ implicit val classTag = {
+ val mirror = runtimeMirror(getClass.getClassLoader)
+ ClassTag[T#JvmType](mirror.runtimeClass(columnType.scalaTag.tpe))
+ }
+
+ Array.fill(buffer.getInt()) {
+ columnType.extract(buffer)
+ }
+ }
+
+ override def next() = dictionary(buffer.getShort())
+
+ override def hasNext = buffer.hasRemaining
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 869673b1fe978..450c142c0baa4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -76,7 +76,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
*/
object AddExchange extends Rule[SparkPlan] {
// TODO: Determine the number of partitions.
- val numPartitions = 8
+ val numPartitions = 150
def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
case operator: SparkPlan =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
index e902e6ced521d..cff4887936ae1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
@@ -36,10 +36,10 @@ case class Generate(
child: SparkPlan)
extends UnaryNode {
- def output =
+ override def output =
if (join) child.output ++ generator.output else generator.output
- def execute() = {
+ override def execute() = {
if (join) {
child.execute().mapPartitions { iter =>
val nullValues = Seq.fill(generator.output.size)(Literal(null))
@@ -52,7 +52,7 @@ case class Generate(
val joinedRow = new JoinedRow
iter.flatMap {row =>
- val outputRows = generator(row)
+ val outputRows = generator.eval(row)
if (outer && outputRows.isEmpty) {
outerProjection(row) :: Nil
} else {
@@ -61,7 +61,7 @@ case class Generate(
}
}
} else {
- child.execute().mapPartitions(iter => iter.flatMap(generator))
+ child.execute().mapPartitions(iter => iter.flatMap(row => generator.eval(row)))
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index acb1ee83a72f6..daa423cb8ea1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.catalyst.plans.{QueryPlan, logical}
import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.columnar.InMemoryColumnarTableScan
abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging {
self: Product =>
@@ -69,6 +70,8 @@ case class SparkLogicalPlan(alreadyPlanned: SparkPlan)
SparkLogicalPlan(
alreadyPlanned match {
case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd)
+ case InMemoryColumnarTableScan(output, child) =>
+ InMemoryColumnarTableScan(output.map(_.newInstance), child)
case _ => sys.error("Multiple instance of the same relation detected.")
}).asInstanceOf[this.type]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
index 1c3196ae2e7b6..d8e1b970c1d88 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
@@ -32,6 +32,13 @@ class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
kryo.setRegistrationRequired(false)
kryo.register(classOf[MutablePair[_, _]])
kryo.register(classOf[Array[Any]])
+ // This is kinda hacky...
+ kryo.register(classOf[scala.collection.immutable.Map$Map1], new MapSerializer)
+ kryo.register(classOf[scala.collection.immutable.Map$Map2], new MapSerializer)
+ kryo.register(classOf[scala.collection.immutable.Map$Map3], new MapSerializer)
+ kryo.register(classOf[scala.collection.immutable.Map$Map4], new MapSerializer)
+ kryo.register(classOf[scala.collection.immutable.Map[_,_]], new MapSerializer)
+ kryo.register(classOf[scala.collection.Map[_,_]], new MapSerializer)
kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]])
@@ -70,3 +77,20 @@ class BigDecimalSerializer extends Serializer[BigDecimal] {
BigDecimal(input.readString())
}
}
+
+/**
+ * Maps do not have a no arg constructor and so cannot be serialized by default. So, we serialize
+ * them as `Array[(k,v)]`.
+ */
+class MapSerializer extends Serializer[Map[_,_]] {
+ def write(kryo: Kryo, output: Output, map: Map[_,_]) {
+ kryo.writeObject(output, map.flatMap(e => Seq(e._1, e._2)).toArray)
+ }
+
+ def read(kryo: Kryo, input: Input, tpe: Class[Map[_,_]]): Map[_,_] = {
+ kryo.readObject(input, classOf[Array[Any]])
+ .sliding(2,2)
+ .map { case Array(k,v) => (k,v) }
+ .toMap
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 86f9d3e0fa954..fe8bd5a508820 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.parquet._
abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
self: SQLContext#SparkPlanner =>
- object SparkEquiInnerJoin extends Strategy {
+ object HashJoin extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case FilteredOperation(predicates, logical.Join(left, right, Inner, condition)) =>
logger.debug(s"Considering join: ${predicates ++ condition}")
@@ -51,8 +51,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
val leftKeys = joinKeys.map(_._1)
val rightKeys = joinKeys.map(_._2)
- val joinOp = execution.SparkEquiInnerJoin(
- leftKeys, rightKeys, planLater(left), planLater(right))
+ val joinOp = execution.HashJoin(
+ leftKeys, rightKeys, BuildRight, planLater(left), planLater(right))
// Make sure other conditions are met if present.
if (otherPredicates.nonEmpty) {
@@ -158,10 +158,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case other => other
}
- object TopK extends Strategy {
+ object TakeOrdered extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case logical.StopAfter(IntegerLiteral(limit), logical.Sort(order, child)) =>
- execution.TopK(limit, order, planLater(child))(sparkContext) :: Nil
+ case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) =>
+ execution.TakeOrdered(limit, order, planLater(child))(sparkContext) :: Nil
case _ => Nil
}
}
@@ -171,10 +171,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// TODO: need to support writing to other types of files. Unify the below code paths.
case logical.WriteToFile(path, child) =>
val relation =
- ParquetRelation.create(path, child, sparkContext.hadoopConfiguration, None)
- InsertIntoParquetTable(relation, planLater(child))(sparkContext) :: Nil
+ ParquetRelation.create(path, child, sparkContext.hadoopConfiguration)
+ InsertIntoParquetTable(relation, planLater(child), overwrite=true)(sparkContext) :: Nil
case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) =>
- InsertIntoParquetTable(table, planLater(child))(sparkContext) :: Nil
+ InsertIntoParquetTable(table, planLater(child), overwrite)(sparkContext) :: Nil
case PhysicalOperation(projectList, filters, relation: ParquetRelation) =>
// TODO: Should be pushing down filters as well.
pruneFilterProject(
@@ -213,8 +213,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
sparkContext.parallelize(data.map(r =>
new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row))
execution.ExistingRdd(output, dataAsRdd) :: Nil
- case logical.StopAfter(IntegerLiteral(limit), child) =>
- execution.StopAfter(limit, planLater(child))(sparkContext) :: Nil
+ case logical.Limit(IntegerLiteral(limit), child) =>
+ execution.Limit(limit, planLater(child))(sparkContext) :: Nil
case Unions(unionChildren) =>
execution.Union(unionChildren.map(planLater))(sparkContext) :: Nil
case logical.Generate(generator, join, outer, _, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala
index 8515a18f18c55..0890faa33b507 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala
@@ -17,14 +17,13 @@
package org.apache.spark.sql.execution
+import java.util.HashMap
+
import org.apache.spark.SparkContext
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
-/* Implicit conversions */
-import org.apache.spark.rdd.PartitionLocalRDDFunctions._
-
/**
* Groups input data by `groupingExpressions` and computes the `aggregateExpressions` for each
* group.
@@ -40,7 +39,7 @@ case class Aggregate(
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
child: SparkPlan)(@transient sc: SparkContext)
- extends UnaryNode {
+ extends UnaryNode with NoBind {
override def requiredChildDistribution =
if (partial) {
@@ -55,61 +54,149 @@ case class Aggregate(
override def otherCopyArgs = sc :: Nil
+ // HACK: Generators don't correctly preserve their output through serializations so we grab
+ // out child's output attributes statically here.
+ val childOutput = child.output
+
def output = aggregateExpressions.map(_.toAttribute)
- /* Replace all aggregate expressions with spark functions that will compute the result. */
- def createAggregateImplementations() = aggregateExpressions.map { agg =>
- val impl = agg transform {
- case a: AggregateExpression => a.newInstance
+ /**
+ * An aggregate that needs to be computed for each row in a group.
+ *
+ * @param unbound Unbound version of this aggregate, used for result substitution.
+ * @param aggregate A bound copy of this aggregate used to create a new aggregation buffer.
+ * @param resultAttribute An attribute used to refer to the result of this aggregate in the final
+ * output.
+ */
+ case class ComputedAggregate(
+ unbound: AggregateExpression,
+ aggregate: AggregateExpression,
+ resultAttribute: AttributeReference)
+
+ /** A list of aggregates that need to be computed for each group. */
+ @transient
+ lazy val computedAggregates = aggregateExpressions.flatMap { agg =>
+ agg.collect {
+ case a: AggregateExpression =>
+ ComputedAggregate(
+ a,
+ BindReferences.bindReference(a, childOutput).asInstanceOf[AggregateExpression],
+ AttributeReference(s"aggResult:$a", a.dataType, nullable = true)())
}
+ }.toArray
+
+ /** The schema of the result of all aggregate evaluations */
+ @transient
+ lazy val computedSchema = computedAggregates.map(_.resultAttribute)
+
+ /** Creates a new aggregate buffer for a group. */
+ def newAggregateBuffer(): Array[AggregateFunction] = {
+ val buffer = new Array[AggregateFunction](computedAggregates.length)
+ var i = 0
+ while (i < computedAggregates.length) {
+ buffer(i) = computedAggregates(i).aggregate.newInstance()
+ i += 1
+ }
+ buffer
+ }
- val remainingAttributes = impl.collect { case a: Attribute => a }
- // If any references exist that are not inside agg functions then the must be grouping exprs
- // in this case we must rebind them to the grouping tuple.
- if (remainingAttributes.nonEmpty) {
- val unaliasedAggregateExpr = agg transform { case Alias(c, _) => c }
-
- // An exact match with a grouping expression
- val exactGroupingExpr = groupingExpressions.indexOf(unaliasedAggregateExpr) match {
- case -1 => None
- case ordinal => Some(BoundReference(ordinal, Alias(impl, "AGGEXPR")().toAttribute))
- }
+ /** Named attributes used to substitute grouping attributes into the final result. */
+ @transient
+ lazy val namedGroups = groupingExpressions.map {
+ case ne: NamedExpression => ne -> ne.toAttribute
+ case e => e -> Alias(e, s"groupingExpr:$e")().toAttribute
+ }
- exactGroupingExpr.getOrElse(
- sys.error(s"$agg is not in grouping expressions: $groupingExpressions"))
- } else {
- impl
+ /**
+ * A map of substitutions that are used to insert the aggregate expressions and grouping
+ * expression into the final result expression.
+ */
+ @transient
+ lazy val resultMap =
+ (computedAggregates.map { agg => agg.unbound -> agg.resultAttribute} ++ namedGroups).toMap
+
+ /**
+ * Substituted version of aggregateExpressions expressions which are used to compute final
+ * output rows given a group and the result of all aggregate computations.
+ */
+ @transient
+ lazy val resultExpressions = aggregateExpressions.map { agg =>
+ agg.transform {
+ case e: Expression if resultMap.contains(e) => resultMap(e)
}
}
def execute() = attachTree(this, "execute") {
- // TODO: If the child of it is an [[catalyst.execution.Exchange]],
- // do not evaluate the groupingExpressions again since we have evaluated it
- // in the [[catalyst.execution.Exchange]].
- val grouped = child.execute().mapPartitions { iter =>
- val buildGrouping = new Projection(groupingExpressions)
- iter.map(row => (buildGrouping(row), row.copy()))
- }.groupByKeyLocally()
-
- val result = grouped.map { case (group, rows) =>
- val aggImplementations = createAggregateImplementations()
-
- // Pull out all the functions so we can feed each row into them.
- val aggFunctions = aggImplementations.flatMap(_ collect { case f: AggregateFunction => f })
-
- rows.foreach { row =>
- aggFunctions.foreach(_.update(row))
+ if (groupingExpressions.isEmpty) {
+ child.execute().mapPartitions { iter =>
+ val buffer = newAggregateBuffer()
+ var currentRow: Row = null
+ while (iter.hasNext) {
+ currentRow = iter.next()
+ var i = 0
+ while (i < buffer.length) {
+ buffer(i).update(currentRow)
+ i += 1
+ }
+ }
+ val resultProjection = new Projection(resultExpressions, computedSchema)
+ val aggregateResults = new GenericMutableRow(computedAggregates.length)
+
+ var i = 0
+ while (i < buffer.length) {
+ aggregateResults(i) = buffer(i).eval(EmptyRow)
+ i += 1
+ }
+
+ Iterator(resultProjection(aggregateResults))
}
- buildRow(aggImplementations.map(_.apply(group)))
- }
-
- // TODO: THIS BREAKS PIPELINING, DOUBLE COMPUTES THE ANSWER, AND USES TOO MUCH MEMORY...
- if (groupingExpressions.isEmpty && result.count == 0) {
- // When there there is no output to the Aggregate operator, we still output an empty row.
- val aggImplementations = createAggregateImplementations()
- sc.makeRDD(buildRow(aggImplementations.map(_.apply(null))) :: Nil)
} else {
- result
+ child.execute().mapPartitions { iter =>
+ val hashTable = new HashMap[Row, Array[AggregateFunction]]
+ val groupingProjection = new MutableProjection(groupingExpressions, childOutput)
+
+ var currentRow: Row = null
+ while (iter.hasNext) {
+ currentRow = iter.next()
+ val currentGroup = groupingProjection(currentRow)
+ var currentBuffer = hashTable.get(currentGroup)
+ if (currentBuffer == null) {
+ currentBuffer = newAggregateBuffer()
+ hashTable.put(currentGroup.copy(), currentBuffer)
+ }
+
+ var i = 0
+ while (i < currentBuffer.length) {
+ currentBuffer(i).update(currentRow)
+ i += 1
+ }
+ }
+
+ new Iterator[Row] {
+ private[this] val hashTableIter = hashTable.entrySet().iterator()
+ private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length)
+ private[this] val resultProjection =
+ new MutableProjection(resultExpressions, computedSchema ++ namedGroups.map(_._2))
+ private[this] val joinedRow = new JoinedRow
+
+ override final def hasNext: Boolean = hashTableIter.hasNext
+
+ override final def next(): Row = {
+ val currentEntry = hashTableIter.next()
+ val currentGroup = currentEntry.getKey
+ val currentBuffer = currentEntry.getValue
+
+ var i = 0
+ while (i < currentBuffer.length) {
+ // Evaluating an aggregate buffer returns the result. No row is required since we
+ // already added all rows in the group using update.
+ aggregateResults(i) = currentBuffer(i).eval(EmptyRow)
+ i += 1
+ }
+ resultProjection(joinedRow(aggregateResults, currentGroup))
+ }
+ }
+ }
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 65cb8f8becefa..ab2e62463764a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -19,65 +19,88 @@ package org.apache.spark.sql.execution
import scala.reflect.runtime.universe.TypeTag
-import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext
-
+import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
+import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{OrderedDistribution, UnspecifiedDistribution}
-import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.util.MutablePair
+
case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
- def output = projectList.map(_.toAttribute)
+ override def output = projectList.map(_.toAttribute)
- def execute() = child.execute().mapPartitions { iter =>
+ override def execute() = child.execute().mapPartitions { iter =>
@transient val reusableProjection = new MutableProjection(projectList)
iter.map(reusableProjection)
}
}
case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
- def output = child.output
+ override def output = child.output
- def execute() = child.execute().mapPartitions { iter =>
- iter.filter(condition.apply(_).asInstanceOf[Boolean])
+ override def execute() = child.execute().mapPartitions { iter =>
+ iter.filter(condition.eval(_).asInstanceOf[Boolean])
}
}
case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: SparkPlan)
extends UnaryNode {
- def output = child.output
+ override def output = child.output
// TODO: How to pick seed?
- def execute() = child.execute().sample(withReplacement, fraction, seed)
+ override def execute() = child.execute().sample(withReplacement, fraction, seed)
}
case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends SparkPlan {
// TODO: attributes output by union should be distinct for nullability purposes
- def output = children.head.output
- def execute() = sc.union(children.map(_.execute()))
+ override def output = children.head.output
+ override def execute() = sc.union(children.map(_.execute()))
override def otherCopyArgs = sc :: Nil
}
-case class StopAfter(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
+/**
+ * Take the first limit elements. Note that the implementation is different depending on whether
+ * this is a terminal operator or not. If it is terminal and is invoked using executeCollect,
+ * this operator uses Spark's take method on the Spark driver. If it is not terminal or is
+ * invoked using execute, we first take the limit on each partition, and then repartition all the
+ * data to a single partition to compute the global limit.
+ */
+case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
+ // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
+ // partition local limit -> exchange into one partition -> partition local limit again
+
override def otherCopyArgs = sc :: Nil
- def output = child.output
+ override def output = child.output
override def executeCollect() = child.execute().map(_.copy()).take(limit)
- // TODO: Terminal split should be implemented differently from non-terminal split.
- // TODO: Pick num splits based on |limit|.
- def execute() = sc.makeRDD(executeCollect(), 1)
+ override def execute() = {
+ val rdd = child.execute().mapPartitions { iter =>
+ val mutablePair = new MutablePair[Boolean, Row]()
+ iter.take(limit).map(row => mutablePair.update(false, row))
+ }
+ val part = new HashPartitioner(1)
+ val shuffled = new ShuffledRDD[Boolean, Row, MutablePair[Boolean, Row]](rdd, part)
+ shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+ shuffled.mapPartitions(_.take(limit).map(_._2))
+ }
}
-case class TopK(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
- (@transient sc: SparkContext) extends UnaryNode {
+/**
+ * Take the first limit elements as defined by the sortOrder. This is logically equivalent to
+ * having a [[Limit]] operator after a [[Sort]] operator. This could have been named TopK, but
+ * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion.
+ */
+case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
+ (@transient sc: SparkContext) extends UnaryNode {
override def otherCopyArgs = sc :: Nil
- def output = child.output
+ override def output = child.output
@transient
lazy val ordering = new RowOrdering(sortOrder)
@@ -86,7 +109,7 @@ case class TopK(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
// TODO: Terminal split should be implemented differently from non-terminal split.
// TODO: Pick num splits based on |limit|.
- def execute() = sc.makeRDD(executeCollect(), 1)
+ override def execute() = sc.makeRDD(executeCollect(), 1)
}
@@ -101,7 +124,7 @@ case class Sort(
@transient
lazy val ordering = new RowOrdering(sortOrder)
- def execute() = attachTree(this, "sort") {
+ override def execute() = attachTree(this, "sort") {
// TODO: Optimize sorting operation?
child.execute()
.mapPartitions(
@@ -109,7 +132,7 @@ case class Sort(
preservesPartitioning = true)
}
- def output = child.output
+ override def output = child.output
}
object ExistingRdd {
@@ -130,6 +153,6 @@ object ExistingRdd {
}
case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
- def execute() = rdd
+ override def execute() = rdd
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index f0d21143ba5d1..c89dae9358bf7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -17,21 +17,22 @@
package org.apache.spark.sql.execution
-import scala.collection.mutable
+import scala.collection.mutable.{ArrayBuffer, BitSet}
-import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
-import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
-import org.apache.spark.rdd.PartitionLocalRDDFunctions._
+sealed abstract class BuildSide
+case object BuildLeft extends BuildSide
+case object BuildRight extends BuildSide
-case class SparkEquiInnerJoin(
+case class HashJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
+ buildSide: BuildSide,
left: SparkPlan,
right: SparkPlan) extends BinaryNode {
@@ -40,33 +41,93 @@ case class SparkEquiInnerJoin(
override def requiredChildDistribution =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+ val (buildPlan, streamedPlan) = buildSide match {
+ case BuildLeft => (left, right)
+ case BuildRight => (right, left)
+ }
+
+ val (buildKeys, streamedKeys) = buildSide match {
+ case BuildLeft => (leftKeys, rightKeys)
+ case BuildRight => (rightKeys, leftKeys)
+ }
+
def output = left.output ++ right.output
- def execute() = attachTree(this, "execute") {
- val leftWithKeys = left.execute().mapPartitions { iter =>
- val generateLeftKeys = new Projection(leftKeys, left.output)
- iter.map(row => (generateLeftKeys(row), row.copy()))
- }
+ @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
+ @transient lazy val streamSideKeyGenerator =
+ () => new MutableProjection(streamedKeys, streamedPlan.output)
- val rightWithKeys = right.execute().mapPartitions { iter =>
- val generateRightKeys = new Projection(rightKeys, right.output)
- iter.map(row => (generateRightKeys(row), row.copy()))
- }
+ def execute() = {
- // Do the join.
- val joined = filterNulls(leftWithKeys).joinLocally(filterNulls(rightWithKeys))
- // Drop join keys and merge input tuples.
- joined.map { case (_, (leftTuple, rightTuple)) => buildRow(leftTuple ++ rightTuple) }
- }
+ buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
+ // TODO: Use Spark's HashMap implementation.
+ val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
+ var currentRow: Row = null
+
+ // Create a mapping of buildKeys -> rows
+ while (buildIter.hasNext) {
+ currentRow = buildIter.next()
+ val rowKey = buildSideKeyGenerator(currentRow)
+ if(!rowKey.anyNull) {
+ val existingMatchList = hashTable.get(rowKey)
+ val matchList = if (existingMatchList == null) {
+ val newMatchList = new ArrayBuffer[Row]()
+ hashTable.put(rowKey, newMatchList)
+ newMatchList
+ } else {
+ existingMatchList
+ }
+ matchList += currentRow.copy()
+ }
+ }
+
+ new Iterator[Row] {
+ private[this] var currentStreamedRow: Row = _
+ private[this] var currentHashMatches: ArrayBuffer[Row] = _
+ private[this] var currentMatchPosition: Int = -1
- /**
- * Filters any rows where the any of the join keys is null, ensuring three-valued
- * logic for the equi-join conditions.
- */
- protected def filterNulls(rdd: RDD[(Row, Row)]) =
- rdd.filter {
- case (key: Seq[_], _) => !key.exists(_ == null)
+ // Mutable per row objects.
+ private[this] val joinRow = new JoinedRow
+
+ private[this] val joinKeys = streamSideKeyGenerator()
+
+ override final def hasNext: Boolean =
+ (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
+ (streamIter.hasNext && fetchNext())
+
+ override final def next() = {
+ val ret = joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
+ currentMatchPosition += 1
+ ret
+ }
+
+ /**
+ * Searches the streamed iterator for the next row that has at least one match in hashtable.
+ *
+ * @return true if the search is successful, and false the streamed iterator runs out of
+ * tuples.
+ */
+ private final def fetchNext(): Boolean = {
+ currentHashMatches = null
+ currentMatchPosition = -1
+
+ while (currentHashMatches == null && streamIter.hasNext) {
+ currentStreamedRow = streamIter.next()
+ if (!joinKeys(currentStreamedRow).anyNull) {
+ currentHashMatches = hashTable.get(joinKeys.currentValue)
+ }
+ }
+
+ if (currentHashMatches == null) {
+ false
+ } else {
+ currentMatchPosition = 0
+ true
+ }
+ }
+ }
}
+ }
}
case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
@@ -95,17 +156,19 @@ case class BroadcastNestedLoopJoin(
def right = broadcast
@transient lazy val boundCondition =
- condition
- .map(c => BindReferences.bindReference(c, left.output ++ right.output))
- .getOrElse(Literal(true))
+ InterpretedPredicate(
+ condition
+ .map(c => BindReferences.bindReference(c, left.output ++ right.output))
+ .getOrElse(Literal(true)))
def execute() = {
val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
- val matchedRows = new mutable.ArrayBuffer[Row]
- val includedBroadcastTuples = new mutable.BitSet(broadcastedRelation.value.size)
+ val matchedRows = new ArrayBuffer[Row]
+ // TODO: Use Spark's BitSet.
+ val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size)
val joinedRow = new JoinedRow
streamedIter.foreach { streamedRow =>
@@ -115,7 +178,7 @@ case class BroadcastNestedLoopJoin(
while (i < broadcastedRelation.value.size) {
// TODO: One bitset per partition instead of per row.
val broadcastedRow = broadcastedRelation.value(i)
- if (boundCondition(joinedRow(streamedRow, broadcastedRow)).asInstanceOf[Boolean]) {
+ if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
matchedRows += buildRow(streamedRow ++ broadcastedRow)
matched = true
includedBroadcastTuples += i
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
index 2b825f84ee910..505ad0a2c77c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
@@ -17,30 +17,29 @@
package org.apache.spark.sql.parquet
-import java.io.{IOException, FileNotFoundException}
-
-import scala.collection.JavaConversions._
+import java.io.IOException
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.permission.FsAction
import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.permission.FsAction
import org.apache.hadoop.mapreduce.Job
-import parquet.hadoop.metadata.{FileMetaData, ParquetMetadata}
import parquet.hadoop.util.ContextUtil
-import parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter}
+import parquet.hadoop.{ParquetOutputFormat, Footer, ParquetFileWriter, ParquetFileReader}
+import parquet.hadoop.metadata.{CompressionCodecName, FileMetaData, ParquetMetadata}
import parquet.io.api.{Binary, RecordConsumer}
+import parquet.schema.{Type => ParquetType, PrimitiveType => ParquetPrimitiveType, MessageType, MessageTypeParser}
import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName}
import parquet.schema.Type.Repetition
-import parquet.schema.{MessageType, MessageTypeParser}
-import parquet.schema.{PrimitiveType => ParquetPrimitiveType}
-import parquet.schema.{Type => ParquetType}
-import org.apache.spark.sql.catalyst.analysis.UnresolvedException
+import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row}
-import org.apache.spark.sql.catalyst.plans.logical.{BaseRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode}
import org.apache.spark.sql.catalyst.types._
+// Implicits
+import scala.collection.JavaConversions._
+
/**
* Relation that consists of data stored in a Parquet columnar format.
*
@@ -48,39 +47,76 @@ import org.apache.spark.sql.catalyst.types._
* of using this class directly.
*
* {{{
- * val parquetRDD = sqlContext.parquetFile("path/to/parequet.file")
+ * val parquetRDD = sqlContext.parquetFile("path/to/parquet.file")
* }}}
*
- * @param tableName The name of the relation that can be used in queries.
* @param path The path to the Parquet file.
*/
-case class ParquetRelation(tableName: String, path: String) extends BaseRelation {
+private[sql] case class ParquetRelation(val path: String)
+ extends LeafNode with MultiInstanceRelation {
+ self: Product =>
- /** Schema derived from ParquetFile **/
+ /** Schema derived from ParquetFile */
def parquetSchema: MessageType =
ParquetTypesConverter
.readMetaData(new Path(path))
.getFileMetaData
.getSchema
- /** Attributes **/
- val attributes =
+ /** Attributes */
+ override val output =
ParquetTypesConverter
- .convertToAttributes(parquetSchema)
+ .convertToAttributes(parquetSchema)
- /** Output **/
- override val output = attributes
+ override def newInstance = ParquetRelation(path).asInstanceOf[this.type]
- // Parquet files have no concepts of keys, therefore no Partitioner
- // Note: we could allow Block level access; needs to be thought through
- override def isPartitioned = false
+ // Equals must also take into account the output attributes so that we can distinguish between
+ // different instances of the same relation,
+ override def equals(other: Any) = other match {
+ case p: ParquetRelation =>
+ p.path == path && p.output == output
+ case _ => false
+ }
}
-object ParquetRelation {
+private[sql] object ParquetRelation {
+
+ def enableLogForwarding() {
+ // Note: Parquet does not use forwarding to parent loggers which
+ // is required for the JUL-SLF4J bridge to work. Also there is
+ // a default logger that appends to Console which needs to be
+ // reset.
+ import org.slf4j.bridge.SLF4JBridgeHandler
+ import java.util.logging.Logger
+ import java.util.logging.LogManager
+
+ val loggerNames = Seq(
+ "parquet.hadoop.ColumnChunkPageWriteStore",
+ "parquet.hadoop.InternalParquetRecordWriter",
+ "parquet.hadoop.ParquetRecordReader",
+ "parquet.hadoop.ParquetInputFormat",
+ "parquet.hadoop.ParquetOutputFormat",
+ "parquet.hadoop.ParquetFileReader",
+ "parquet.hadoop.InternalParquetRecordReader",
+ "parquet.hadoop.codec.CodecConfig")
+ LogManager.getLogManager.reset()
+ SLF4JBridgeHandler.install()
+ for(name <- loggerNames) {
+ val logger = Logger.getLogger(name)
+ logger.setParent(Logger.getLogger(Logger.GLOBAL_LOGGER_NAME))
+ logger.setUseParentHandlers(true)
+ }
+ }
// The element type for the RDDs that this relation maps to.
type RowType = org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+ // The compression type
+ type CompressionType = parquet.hadoop.metadata.CompressionCodecName
+
+ // The default compression
+ val defaultCompression = CompressionCodecName.GZIP
+
/**
* Creates a new ParquetRelation and underlying Parquetfile for the given LogicalPlan. Note that
* this is used inside [[org.apache.spark.sql.execution.SparkStrategies SparkStrategies]] to
@@ -89,24 +125,39 @@ object ParquetRelation {
*
* @param pathString The directory the Parquetfile will be stored in.
* @param child The child node that will be used for extracting the schema.
- * @param conf A configuration configuration to be used.
- * @param tableName The name of the resulting relation.
- * @return An empty ParquetRelation inferred metadata.
+ * @param conf A configuration to be used.
+ * @return An empty ParquetRelation with inferred metadata.
*/
def create(pathString: String,
child: LogicalPlan,
- conf: Configuration,
- tableName: Option[String]): ParquetRelation = {
+ conf: Configuration): ParquetRelation = {
if (!child.resolved) {
throw new UnresolvedException[LogicalPlan](
child,
"Attempt to create Parquet table from unresolved child (when schema is not available)")
}
+ createEmpty(pathString, child.output, conf)
+ }
- val name = s"${tableName.getOrElse(child.nodeName)}_parquet"
+ /**
+ * Creates an empty ParquetRelation and underlying Parquetfile that only
+ * consists of the Metadata for the given schema.
+ *
+ * @param pathString The directory the Parquetfile will be stored in.
+ * @param attributes The schema of the relation.
+ * @param conf A configuration to be used.
+ * @return An empty ParquetRelation.
+ */
+ def createEmpty(pathString: String,
+ attributes: Seq[Attribute],
+ conf: Configuration): ParquetRelation = {
val path = checkPath(pathString, conf)
- ParquetTypesConverter.writeMetaData(child.output, path, conf)
- new ParquetRelation(name, path.toString)
+ if (conf.get(ParquetOutputFormat.COMPRESSION) == null) {
+ conf.set(ParquetOutputFormat.COMPRESSION, ParquetRelation.defaultCompression.name())
+ }
+ ParquetRelation.enableLogForwarding()
+ ParquetTypesConverter.writeMetaData(attributes, path, conf)
+ new ParquetRelation(path.toString)
}
private def checkPath(pathStr: String, conf: Configuration): Path = {
@@ -132,7 +183,7 @@ object ParquetRelation {
}
}
-object ParquetTypesConverter {
+private[parquet] object ParquetTypesConverter {
def toDataType(parquetType : ParquetPrimitiveTypeName): DataType = parquetType match {
// for now map binary to string type
// TODO: figure out how Parquet uses strings or why we can't use them in a MessageType schema
@@ -231,6 +282,7 @@ object ParquetTypesConverter {
extraMetadata,
"Spark")
+ ParquetRelation.enableLogForwarding()
ParquetFileWriter.writeMetadataFile(
conf,
path,
@@ -257,16 +309,24 @@ object ParquetTypesConverter {
throw new IllegalArgumentException(s"Incorrectly formatted Parquet metadata path $origPath")
}
val path = origPath.makeQualified(fs)
+ if (!fs.getFileStatus(path).isDir) {
+ throw new IllegalArgumentException(
+ s"Expected $path for be a directory with Parquet files/metadata")
+ }
+ ParquetRelation.enableLogForwarding()
val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE)
+ // if this is a new table that was just created we will find only the metadata file
if (fs.exists(metadataPath) && fs.isFile(metadataPath)) {
- // TODO: improve exception handling, etc.
ParquetFileReader.readFooter(conf, metadataPath)
} else {
- if (!fs.exists(path) || !fs.isFile(path)) {
- throw new FileNotFoundException(
- s"Could not find file ${path.toString} when trying to read metadata")
+ // there may be one or more Parquet files in the given directory
+ val footers = ParquetFileReader.readFooters(conf, fs.getFileStatus(path))
+ // TODO: for now we assume that all footers (if there is more than one) have identical
+ // metadata; we may want to add a check here at some point
+ if (footers.size() == 0) {
+ throw new IllegalArgumentException(s"Could not find Parquet metadata at path $path")
}
- ParquetFileReader.readFooter(conf, path)
+ footers(0).getParquetMetadata
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index 7285f5b88b9bf..d5846baa72ada 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -24,26 +24,29 @@ import java.util.Date
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce._
-import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
+import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
+import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat, FileOutputCommitter}
-import parquet.hadoop.util.ContextUtil
import parquet.hadoop.{ParquetInputFormat, ParquetOutputFormat}
+import parquet.hadoop.util.ContextUtil
import parquet.io.InvalidRecordException
import parquet.schema.MessageType
+import org.apache.spark.{SerializableWritable, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row}
import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
-import org.apache.spark.{SerializableWritable, SparkContext, TaskContext}
/**
* Parquet table scan operator. Imports the file that backs the given
* [[ParquetRelation]] as a RDD[Row].
*/
case class ParquetTableScan(
- @transient output: Seq[Attribute],
- @transient relation: ParquetRelation,
- @transient columnPruningPred: Option[Expression])(
+ // note: output cannot be transient, see
+ // https://issues.apache.org/jira/browse/SPARK-1367
+ output: Seq[Attribute],
+ relation: ParquetRelation,
+ columnPruningPred: Option[Expression])(
@transient val sc: SparkContext)
extends LeafNode {
@@ -53,6 +56,12 @@ case class ParquetTableScan(
job,
classOf[org.apache.spark.sql.parquet.RowReadSupport])
val conf: Configuration = ContextUtil.getConfiguration(job)
+ val fileList = FileSystemHelper.listFiles(relation.path, conf)
+ // add all paths in the directory but skip "hidden" ones such
+ // as "_SUCCESS" and "_metadata"
+ for (path <- fileList if !path.getName.startsWith("_")) {
+ NewFileInputFormat.addInputPath(job, path)
+ }
conf.set(
RowReadSupport.PARQUET_ROW_REQUESTED_SCHEMA,
ParquetTypesConverter.convertFromAttributes(output).toString)
@@ -63,14 +72,12 @@ case class ParquetTableScan(
``FilteredRecordReader`` (via Configuration, for example). Simple
filter-rows-by-column-values however should be supported.
*/
- sc.newAPIHadoopFile(
- relation.path,
- classOf[ParquetInputFormat[Row]],
- classOf[Void], classOf[Row],
- conf)
+ sc.newAPIHadoopRDD(conf, classOf[ParquetInputFormat[Row]], classOf[Void], classOf[Row])
.map(_._2)
}
+ override def otherCopyArgs = sc :: Nil
+
/**
* Applies a (candidate) projection.
*
@@ -108,15 +115,31 @@ case class ParquetTableScan(
}
}
+/**
+ * Operator that acts as a sink for queries on RDDs and can be used to
+ * store the output inside a directory of Parquet files. This operator
+ * is similar to Hive's INSERT INTO TABLE operation in the sense that
+ * one can choose to either overwrite or append to a directory. Note
+ * that consecutive insertions to the same table must have compatible
+ * (source) schemas.
+ *
+ * WARNING: EXPERIMENTAL! InsertIntoParquetTable with overwrite=false may
+ * cause data corruption in the case that multiple users try to append to
+ * the same table simultaneously. Inserting into a table that was
+ * previously generated by other means (e.g., by creating an HDFS
+ * directory and importing Parquet files generated by other tools) may
+ * cause unpredicted behaviour and therefore results in a RuntimeException
+ * (only detected via filename pattern so will not catch all cases).
+ */
case class InsertIntoParquetTable(
- @transient relation: ParquetRelation,
- @transient child: SparkPlan)(
+ relation: ParquetRelation,
+ child: SparkPlan,
+ overwrite: Boolean = false)(
@transient val sc: SparkContext)
extends UnaryNode with SparkHadoopMapReduceUtil {
/**
- * Inserts all the rows in the Parquet file. Note that OVERWRITE is implicit, since
- * Parquet files are write-once.
+ * Inserts all rows into the Parquet file.
*/
override def execute() = {
// TODO: currently we do not check whether the "schema"s are compatible
@@ -135,19 +158,21 @@ case class InsertIntoParquetTable(
classOf[org.apache.spark.sql.parquet.RowWriteSupport])
// TODO: move that to function in object
- val conf = job.getConfiguration
+ val conf = ContextUtil.getConfiguration(job)
conf.set(RowWriteSupport.PARQUET_ROW_SCHEMA, relation.parquetSchema.toString)
val fspath = new Path(relation.path)
val fs = fspath.getFileSystem(conf)
- try {
- fs.delete(fspath, true)
- } catch {
- case e: IOException =>
- throw new IOException(
- s"Unable to clear output directory ${fspath.toString} prior"
- + s" to InsertIntoParquetTable:\n${e.toString}")
+ if (overwrite) {
+ try {
+ fs.delete(fspath, true)
+ } catch {
+ case e: IOException =>
+ throw new IOException(
+ s"Unable to clear output directory ${fspath.toString} prior"
+ + s" to InsertIntoParquetTable:\n${e.toString}")
+ }
}
saveAsHadoopFile(childRdd, relation.path.toString, conf)
@@ -157,6 +182,8 @@ case class InsertIntoParquetTable(
override def output = child.output
+ override def otherCopyArgs = sc :: Nil
+
// based on ``saveAsNewAPIHadoopFile`` in [[PairRDDFunctions]]
// TODO: Maybe PairRDDFunctions should use Product2 instead of Tuple2?
// .. then we could use the default one and could use [[MutablePair]]
@@ -167,15 +194,21 @@ case class InsertIntoParquetTable(
conf: Configuration) {
val job = new Job(conf)
val keyType = classOf[Void]
- val outputFormatType = classOf[parquet.hadoop.ParquetOutputFormat[Row]]
job.setOutputKeyClass(keyType)
job.setOutputValueClass(classOf[Row])
- val wrappedConf = new SerializableWritable(job.getConfiguration)
NewFileOutputFormat.setOutputPath(job, new Path(path))
+ val wrappedConf = new SerializableWritable(job.getConfiguration)
val formatter = new SimpleDateFormat("yyyyMMddHHmm")
val jobtrackerID = formatter.format(new Date())
val stageId = sc.newRddId()
+ val taskIdOffset =
+ if (overwrite) 1
+ else {
+ FileSystemHelper
+ .findMaxTaskId(NewFileOutputFormat.getOutputPath(job).toString, job.getConfiguration) + 1
+ }
+
def writeShard(context: TaskContext, iter: Iterator[Row]): Int = {
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
@@ -184,7 +217,7 @@ case class InsertIntoParquetTable(
val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
attemptNumber)
val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
- val format = outputFormatType.newInstance
+ val format = new AppendingParquetOutputFormat(taskIdOffset)
val committer = format.getOutputCommitter(hadoopContext)
committer.setupTask(hadoopContext)
val writer = format.getRecordWriter(hadoopContext)
@@ -196,7 +229,7 @@ case class InsertIntoParquetTable(
committer.commitTask(hadoopContext)
return 1
}
- val jobFormat = outputFormatType.newInstance
+ val jobFormat = new AppendingParquetOutputFormat(taskIdOffset)
/* apparently we need a TaskAttemptID to construct an OutputCommitter;
* however we're only going to use this local OutputCommitter for
* setupJob/commitJob, so we just use a dummy "map" task.
@@ -210,3 +243,55 @@ case class InsertIntoParquetTable(
}
}
+// TODO: this will be able to append to directories it created itself, not necessarily
+// to imported ones
+private[parquet] class AppendingParquetOutputFormat(offset: Int)
+ extends parquet.hadoop.ParquetOutputFormat[Row] {
+ // override to accept existing directories as valid output directory
+ override def checkOutputSpecs(job: JobContext): Unit = {}
+
+ // override to choose output filename so not overwrite existing ones
+ override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
+ val taskId: TaskID = context.getTaskAttemptID.getTaskID
+ val partition: Int = taskId.getId
+ val filename = s"part-r-${partition + offset}.parquet"
+ val committer: FileOutputCommitter =
+ getOutputCommitter(context).asInstanceOf[FileOutputCommitter]
+ new Path(committer.getWorkPath, filename)
+ }
+}
+
+private[parquet] object FileSystemHelper {
+ def listFiles(pathStr: String, conf: Configuration): Seq[Path] = {
+ val origPath = new Path(pathStr)
+ val fs = origPath.getFileSystem(conf)
+ if (fs == null) {
+ throw new IllegalArgumentException(
+ s"ParquetTableOperations: Path $origPath is incorrectly formatted")
+ }
+ val path = origPath.makeQualified(fs)
+ if (!fs.exists(path) || !fs.getFileStatus(path).isDir) {
+ throw new IllegalArgumentException(
+ s"ParquetTableOperations: path $path does not exist or is not a directory")
+ }
+ fs.listStatus(path).map(_.getPath)
+ }
+
+ // finds the maximum taskid in the output file names at the given path
+ def findMaxTaskId(pathStr: String, conf: Configuration): Int = {
+ val files = FileSystemHelper.listFiles(pathStr, conf)
+ // filename pattern is part-r-.parquet
+ val nameP = new scala.util.matching.Regex("""part-r-(\d{1,}).parquet""", "taskid")
+ val hiddenFileP = new scala.util.matching.Regex("_.*")
+ files.map(_.getName).map {
+ case nameP(taskid) => taskid.toInt
+ case hiddenFileP() => 0
+ case other: String => {
+ sys.error("ERROR: attempting to append to set of Parquet files and found file" +
+ s"that does not match name pattern: $other")
+ 0
+ }
+ case _ => 0
+ }.reduceLeft((a, b) => if (a < b) b else a)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index c21e400282004..84b1b4609458b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -35,7 +35,8 @@ import org.apache.spark.sql.catalyst.types._
*
*@param root The root group converter for the record.
*/
-class RowRecordMaterializer(root: CatalystGroupConverter) extends RecordMaterializer[Row] {
+private[parquet] class RowRecordMaterializer(root: CatalystGroupConverter)
+ extends RecordMaterializer[Row] {
def this(parquetSchema: MessageType) =
this(new CatalystGroupConverter(ParquetTypesConverter.convertToAttributes(parquetSchema)))
@@ -48,14 +49,14 @@ class RowRecordMaterializer(root: CatalystGroupConverter) extends RecordMaterial
/**
* A `parquet.hadoop.api.ReadSupport` for Row objects.
*/
-class RowReadSupport extends ReadSupport[Row] with Logging {
+private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging {
override def prepareForRead(
conf: Configuration,
stringMap: java.util.Map[String, String],
fileSchema: MessageType,
readContext: ReadContext): RecordMaterializer[Row] = {
- log.debug(s"preparing for read with schema ${fileSchema.toString}")
+ log.debug(s"preparing for read with file schema $fileSchema")
new RowRecordMaterializer(readContext.getRequestedSchema)
}
@@ -67,20 +68,20 @@ class RowReadSupport extends ReadSupport[Row] with Logging {
configuration.get(RowReadSupport.PARQUET_ROW_REQUESTED_SCHEMA, fileSchema.toString)
val requested_schema =
MessageTypeParser.parseMessageType(requested_schema_string)
-
- log.debug(s"read support initialized for original schema ${requested_schema.toString}")
+ log.debug(s"read support initialized for requested schema $requested_schema")
+ ParquetRelation.enableLogForwarding()
new ReadContext(requested_schema, keyValueMetaData)
}
}
-object RowReadSupport {
+private[parquet] object RowReadSupport {
val PARQUET_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema"
}
/**
* A `parquet.hadoop.api.WriteSupport` for Row ojects.
*/
-class RowWriteSupport extends WriteSupport[Row] with Logging {
+private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
def setSchema(schema: MessageType, configuration: Configuration) {
// for testing
this.schema = schema
@@ -104,6 +105,8 @@ class RowWriteSupport extends WriteSupport[Row] with Logging {
override def init(configuration: Configuration): WriteSupport.WriteContext = {
schema = if (schema == null) getSchema(configuration) else schema
attributes = ParquetTypesConverter.convertToAttributes(schema)
+ log.debug(s"write support initialized for requested schema $schema")
+ ParquetRelation.enableLogForwarding()
new WriteSupport.WriteContext(
schema,
new java.util.HashMap[java.lang.String, java.lang.String]())
@@ -111,10 +114,16 @@ class RowWriteSupport extends WriteSupport[Row] with Logging {
override def prepareForWrite(recordConsumer: RecordConsumer): Unit = {
writer = recordConsumer
+ log.debug(s"preparing for write with schema $schema")
}
// TODO: add groups (nested fields)
override def write(record: Row): Unit = {
+ if (attributes.size > record.size) {
+ throw new IndexOutOfBoundsException(
+ s"Trying to write more fields than contained in row (${attributes.size}>${record.size})")
+ }
+
var index = 0
writer.startMessage()
while(index < attributes.size) {
@@ -130,7 +139,7 @@ class RowWriteSupport extends WriteSupport[Row] with Logging {
}
}
-object RowWriteSupport {
+private[parquet] object RowWriteSupport {
val PARQUET_ROW_SCHEMA: String = "org.apache.spark.sql.parquet.row.schema"
}
@@ -139,7 +148,7 @@ object RowWriteSupport {
*
* @param schema The corresponding Catalyst schema in the form of a list of attributes.
*/
-class CatalystGroupConverter(
+private[parquet] class CatalystGroupConverter(
schema: Seq[Attribute],
protected[parquet] val current: ParquetRelation.RowType) extends GroupConverter {
@@ -177,13 +186,12 @@ class CatalystGroupConverter(
* @param parent The parent group converter.
* @param fieldIndex The index inside the record.
*/
-class CatalystPrimitiveConverter(
+private[parquet] class CatalystPrimitiveConverter(
parent: CatalystGroupConverter,
fieldIndex: Int) extends PrimitiveConverter {
// TODO: consider refactoring these together with ParquetTypesConverter
override def addBinary(value: Binary): Unit =
- // TODO: fix this once a setBinary will become available in MutableRow
- parent.getCurrentRecord.setByte(fieldIndex, value.getBytes.apply(0))
+ parent.getCurrentRecord.update(fieldIndex, value.getBytes)
override def addBoolean(value: Boolean): Unit =
parent.getCurrentRecord.setBoolean(fieldIndex, value)
@@ -208,10 +216,9 @@ class CatalystPrimitiveConverter(
* @param parent The parent group converter.
* @param fieldIndex The index inside the record.
*/
-class CatalystPrimitiveStringConverter(
+private[parquet] class CatalystPrimitiveStringConverter(
parent: CatalystGroupConverter,
fieldIndex: Int) extends CatalystPrimitiveConverter(parent, fieldIndex) {
override def addBinary(value: Binary): Unit =
parent.getCurrentRecord.setString(fieldIndex, value.toStringUsingUTF8)
}
-
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
index 3340c3ff81f0a..728e3dd1dc02b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
@@ -26,7 +26,7 @@ import parquet.hadoop.util.ContextUtil
import parquet.schema.{MessageType, MessageTypeParser}
import org.apache.spark.sql.catalyst.expressions.GenericRow
-import org.apache.spark.sql.catalyst.util.getTempFilePath
+import org.apache.spark.util.Utils
object ParquetTestData {
@@ -64,13 +64,13 @@ object ParquetTestData {
"mylong:Long"
)
- val testFile = getTempFilePath("testParquetFile").getCanonicalFile
+ val testDir = Utils.createTempDir()
- lazy val testData = new ParquetRelation("testData", testFile.toURI.toString)
+ lazy val testData = new ParquetRelation(testDir.toURI.toString)
def writeFile() = {
- testFile.delete
- val path: Path = new Path(testFile.toURI)
+ testDir.delete
+ val path: Path = new Path(new Path(testDir.toURI), new Path("part-r-0.parquet"))
val job = new Job()
val configuration: Configuration = ContextUtil.getConfiguration(job)
val schema: MessageType = MessageTypeParser.parseMessageType(testSchema)
diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties
index 7bb6789bd33a5..dffd15a61838b 100644
--- a/sql/core/src/test/resources/log4j.properties
+++ b/sql/core/src/test/resources/log4j.properties
@@ -45,8 +45,6 @@ log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=OFF
log4j.additivity.hive.ql.metadata.Hive=false
log4j.logger.hive.ql.metadata.Hive=OFF
-# Parquet logging
-parquet.hadoop.InternalParquetRecordReader=WARN
-log4j.logger.parquet.hadoop.InternalParquetRecordReader=WARN
-parquet.hadoop.ParquetInputFormat=WARN
-log4j.logger.parquet.hadoop.ParquetInputFormat=WARN
+# Parquet related logging
+log4j.logger.parquet.hadoop=WARN
+log4j.logger.org.apache.spark.sql.parquet=INFO
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
new file mode 100644
index 0000000000000..7c6a642278226
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.scalatest.FunSuite
+import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.execution.SparkLogicalPlan
+import org.apache.spark.sql.columnar.InMemoryColumnarTableScan
+
+class CachedTableSuite extends QueryTest {
+ TestData // Load test tables.
+
+ test("read from cached table and uncache") {
+ TestSQLContext.cacheTable("testData")
+
+ checkAnswer(
+ TestSQLContext.table("testData"),
+ testData.collect().toSeq
+ )
+
+ TestSQLContext.table("testData").queryExecution.analyzed match {
+ case SparkLogicalPlan(_ : InMemoryColumnarTableScan) => // Found evidence of caching
+ case noCache => fail(s"No cache node found in plan $noCache")
+ }
+
+ TestSQLContext.uncacheTable("testData")
+
+ checkAnswer(
+ TestSQLContext.table("testData"),
+ testData.collect().toSeq
+ )
+
+ TestSQLContext.table("testData").queryExecution.analyzed match {
+ case cachePlan @ SparkLogicalPlan(_ : InMemoryColumnarTableScan) =>
+ fail(s"Table still cached after uncache: $cachePlan")
+ case noCache => // Table uncached successfully
+ }
+ }
+
+ test("correct error on uncache of non-cached table") {
+ intercept[IllegalArgumentException] {
+ TestSQLContext.uncacheTable("testData")
+ }
+ }
+
+ test("SELECT Star Cached Table") {
+ TestSQLContext.sql("SELECT * FROM testData").registerAsTable("selectStar")
+ TestSQLContext.cacheTable("selectStar")
+ TestSQLContext.sql("SELECT * FROM selectStar")
+ TestSQLContext.uncacheTable("selectStar")
+ }
+
+ test("Self-join cached") {
+ TestSQLContext.cacheTable("testData")
+ TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key")
+ TestSQLContext.uncacheTable("testData")
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 2524a37cbac13..be0f4a4c73b36 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -119,8 +119,8 @@ class DslQuerySuite extends QueryTest {
}
test("inner join, where, multiple matches") {
- val x = testData2.where('a === 1).subquery('x)
- val y = testData2.where('a === 1).subquery('y)
+ val x = testData2.where('a === 1).as('x)
+ val y = testData2.where('a === 1).as('y)
checkAnswer(
x.join(y).where("x.a".attr === "y.a".attr),
(1,1,1,1) ::
@@ -131,8 +131,8 @@ class DslQuerySuite extends QueryTest {
}
test("inner join, no matches") {
- val x = testData2.where('a === 1).subquery('x)
- val y = testData2.where('a === 2).subquery('y)
+ val x = testData2.where('a === 1).as('x)
+ val y = testData2.where('a === 2).as('y)
checkAnswer(
x.join(y).where("x.a".attr === "y.a".attr),
Nil)
@@ -140,8 +140,8 @@ class DslQuerySuite extends QueryTest {
test("big inner join, 4 matches per row") {
val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData)
- val bigDataX = bigData.subquery('x)
- val bigDataY = bigData.subquery('y)
+ val bigDataX = bigData.as('x)
+ val bigDataY = bigData.as('y)
checkAnswer(
bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr),
@@ -181,8 +181,8 @@ class DslQuerySuite extends QueryTest {
}
test("full outer join") {
- val left = upperCaseData.where('N <= 4).subquery('left)
- val right = upperCaseData.where('N >= 3).subquery('right)
+ val left = upperCaseData.where('N <= 4).as('left)
+ val right = upperCaseData.where('N >= 3).as('right)
checkAnswer(
left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index fa4a1d5189ea6..4c4fd6dbbedb4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -216,4 +216,31 @@ class SQLQuerySuite extends QueryTest {
(null, null, 5, "E") ::
(null, null, 6, "F") :: Nil)
}
+
+ test("select with table name as qualifier") {
+ checkAnswer(
+ sql("SELECT testData.value FROM testData WHERE testData.key = 1"),
+ Seq(Seq("1")))
+ }
+
+ test("inner join ON with table name as qualifier") {
+ checkAnswer(
+ sql("SELECT * FROM upperCaseData JOIN lowerCaseData ON lowerCaseData.n = upperCaseData.N"),
+ Seq(
+ (1, "A", 1, "a"),
+ (2, "B", 2, "b"),
+ (3, "C", 3, "c"),
+ (4, "D", 4, "d")))
+ }
+
+ test("qualified select with inner join ON with table name as qualifier") {
+ checkAnswer(
+ sql("SELECT upperCaseData.N, upperCaseData.L FROM upperCaseData JOIN lowerCaseData " +
+ "ON lowerCaseData.n = upperCaseData.N"),
+ Seq(
+ (1, "A"),
+ (2, "B"),
+ (3, "C"),
+ (4, "D")))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
new file mode 100644
index 0000000000000..65eae3357a21e
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.sql.Timestamp
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.test.TestSQLContext._
+
+case class ReflectData(
+ stringField: String,
+ intField: Int,
+ longField: Long,
+ floatField: Float,
+ doubleField: Double,
+ shortField: Short,
+ byteField: Byte,
+ booleanField: Boolean,
+ decimalField: BigDecimal,
+ timestampField: Timestamp,
+ seqInt: Seq[Int])
+
+case class ReflectBinary(data: Array[Byte])
+
+class ScalaReflectionRelationSuite extends FunSuite {
+ test("query case class RDD") {
+ val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
+ BigDecimal(1), new Timestamp(12345), Seq(1,2,3))
+ val rdd = sparkContext.parallelize(data :: Nil)
+ rdd.registerAsTable("reflectData")
+
+ assert(sql("SELECT * FROM reflectData").collect().head === data.productIterator.toSeq)
+ }
+
+ // Equality is broken for Arrays, so we test that separately.
+ test("query binary data") {
+ val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil)
+ rdd.registerAsTable("reflectBinary")
+
+ val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]]
+ assert(result.toSeq === Seq[Byte](1))
+ }
+}
\ No newline at end of file
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala
new file mode 100644
index 0000000000000..def0e046a3831
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java
+
+import scala.beans.BeanProperty
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.api.java.JavaSparkContext
+import org.apache.spark.sql.test.TestSQLContext
+
+// Implicits
+import scala.collection.JavaConversions._
+
+class PersonBean extends Serializable {
+ @BeanProperty
+ var name: String = _
+
+ @BeanProperty
+ var age: Int = _
+}
+
+class JavaSQLSuite extends FunSuite {
+ val javaCtx = new JavaSparkContext(TestSQLContext.sparkContext)
+ val javaSqlCtx = new JavaSQLContext(javaCtx)
+
+ test("schema from JavaBeans") {
+ val person = new PersonBean
+ person.setName("Michael")
+ person.setAge(29)
+
+ val rdd = javaCtx.parallelize(person :: Nil)
+ val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[PersonBean])
+
+ schemaRDD.registerAsTable("people")
+ javaSqlCtx.sql("SELECT * FROM people").collect()
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
new file mode 100644
index 0000000000000..78640b876d4aa
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.catalyst.types._
+
+class ColumnStatsSuite extends FunSuite {
+ testColumnStats(classOf[BooleanColumnStats], BOOLEAN)
+ testColumnStats(classOf[ByteColumnStats], BYTE)
+ testColumnStats(classOf[ShortColumnStats], SHORT)
+ testColumnStats(classOf[IntColumnStats], INT)
+ testColumnStats(classOf[LongColumnStats], LONG)
+ testColumnStats(classOf[FloatColumnStats], FLOAT)
+ testColumnStats(classOf[DoubleColumnStats], DOUBLE)
+ testColumnStats(classOf[StringColumnStats], STRING)
+
+ def testColumnStats[T <: NativeType, U <: NativeColumnStats[T]](
+ columnStatsClass: Class[U],
+ columnType: NativeColumnType[T]) {
+
+ val columnStatsName = columnStatsClass.getSimpleName
+
+ test(s"$columnStatsName: empty") {
+ val columnStats = columnStatsClass.newInstance()
+ expectResult(columnStats.initialBounds, "Wrong initial bounds") {
+ (columnStats.lowerBound, columnStats.upperBound)
+ }
+ }
+
+ test(s"$columnStatsName: non-empty") {
+ import ColumnarTestUtils._
+
+ val columnStats = columnStatsClass.newInstance()
+ val rows = Seq.fill(10)(makeRandomRow(columnType))
+ rows.foreach(columnStats.gatherStats(_, 0))
+
+ val values = rows.map(_.head.asInstanceOf[T#JvmType])
+ val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]]
+
+ expectResult(values.min(ordering), "Wrong lower bound")(columnStats.lowerBound)
+ expectResult(values.max(ordering), "Wrong upper bound")(columnStats.upperBound)
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index 2d431affbcfcc..1d3608ed2d9ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -19,46 +19,56 @@ package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
-import scala.util.Random
-
import org.scalatest.FunSuite
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.columnar.ColumnarTestUtils._
import org.apache.spark.sql.execution.SparkSqlSerializer
class ColumnTypeSuite extends FunSuite {
- val columnTypes = Seq(INT, SHORT, LONG, BYTE, DOUBLE, FLOAT, STRING, BINARY, GENERIC)
+ val DEFAULT_BUFFER_SIZE = 512
test("defaultSize") {
- val defaultSize = Seq(4, 2, 8, 1, 8, 4, 8, 16, 16)
+ val checks = Map(
+ INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4,
+ BOOLEAN -> 1, STRING -> 8, BINARY -> 16, GENERIC -> 16)
- columnTypes.zip(defaultSize).foreach { case (columnType, size) =>
- assert(columnType.defaultSize === size)
+ checks.foreach { case (columnType, expectedSize) =>
+ expectResult(expectedSize, s"Wrong defaultSize for $columnType") {
+ columnType.defaultSize
+ }
}
}
test("actualSize") {
- val expectedSizes = Seq(4, 2, 8, 1, 8, 4, 4 + 5, 4 + 4, 4 + 11)
- val actualSizes = Seq(
- INT.actualSize(Int.MaxValue),
- SHORT.actualSize(Short.MaxValue),
- LONG.actualSize(Long.MaxValue),
- BYTE.actualSize(Byte.MaxValue),
- DOUBLE.actualSize(Double.MaxValue),
- FLOAT.actualSize(Float.MaxValue),
- STRING.actualSize("hello"),
- BINARY.actualSize(new Array[Byte](4)),
- GENERIC.actualSize(SparkSqlSerializer.serialize(Map(1 -> "a"))))
-
- expectedSizes.zip(actualSizes).foreach { case (expected, actual) =>
- assert(expected === actual)
+ def checkActualSize[T <: DataType, JvmType](
+ columnType: ColumnType[T, JvmType],
+ value: JvmType,
+ expected: Int) {
+
+ expectResult(expected, s"Wrong actualSize for $columnType") {
+ columnType.actualSize(value)
+ }
}
+
+ checkActualSize(INT, Int.MaxValue, 4)
+ checkActualSize(SHORT, Short.MaxValue, 2)
+ checkActualSize(LONG, Long.MaxValue, 8)
+ checkActualSize(BYTE, Byte.MaxValue, 1)
+ checkActualSize(DOUBLE, Double.MaxValue, 8)
+ checkActualSize(FLOAT, Float.MaxValue, 4)
+ checkActualSize(BOOLEAN, true, 1)
+ checkActualSize(STRING, "hello", 4 + 5)
+
+ val binary = Array.fill[Byte](4)(0: Byte)
+ checkActualSize(BINARY, binary, 4 + 4)
+
+ val generic = Map(1 -> "a")
+ checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 11)
}
- testNumericColumnType[BooleanType.type, Boolean](
+ testNativeColumnType[BooleanType.type](
BOOLEAN,
- Array.fill(4)(Random.nextBoolean()),
- ByteBuffer.allocate(32),
(buffer: ByteBuffer, v: Boolean) => {
buffer.put((if (v) 1 else 0).toByte)
},
@@ -66,105 +76,42 @@ class ColumnTypeSuite extends FunSuite {
buffer.get() == 1
})
- testNumericColumnType[IntegerType.type, Int](
- INT,
- Array.fill(4)(Random.nextInt()),
- ByteBuffer.allocate(32),
- (_: ByteBuffer).putInt(_),
- (_: ByteBuffer).getInt)
-
- testNumericColumnType[ShortType.type, Short](
- SHORT,
- Array.fill(4)(Random.nextInt(Short.MaxValue).asInstanceOf[Short]),
- ByteBuffer.allocate(32),
- (_: ByteBuffer).putShort(_),
- (_: ByteBuffer).getShort)
-
- testNumericColumnType[LongType.type, Long](
- LONG,
- Array.fill(4)(Random.nextLong()),
- ByteBuffer.allocate(64),
- (_: ByteBuffer).putLong(_),
- (_: ByteBuffer).getLong)
-
- testNumericColumnType[ByteType.type, Byte](
- BYTE,
- Array.fill(4)(Random.nextInt(Byte.MaxValue).asInstanceOf[Byte]),
- ByteBuffer.allocate(64),
- (_: ByteBuffer).put(_),
- (_: ByteBuffer).get)
-
- testNumericColumnType[DoubleType.type, Double](
- DOUBLE,
- Array.fill(4)(Random.nextDouble()),
- ByteBuffer.allocate(64),
- (_: ByteBuffer).putDouble(_),
- (_: ByteBuffer).getDouble)
-
- testNumericColumnType[FloatType.type, Float](
- FLOAT,
- Array.fill(4)(Random.nextFloat()),
- ByteBuffer.allocate(64),
- (_: ByteBuffer).putFloat(_),
- (_: ByteBuffer).getFloat)
-
- test("STRING") {
- val buffer = ByteBuffer.allocate(128)
- val seq = Array("hello", "world", "spark", "sql")
-
- seq.map(_.getBytes).foreach { bytes: Array[Byte] =>
- buffer.putInt(bytes.length).put(bytes)
- }
+ testNativeColumnType[IntegerType.type](INT, _.putInt(_), _.getInt)
- buffer.rewind()
- seq.foreach { s =>
- assert(s === STRING.extract(buffer))
- }
+ testNativeColumnType[ShortType.type](SHORT, _.putShort(_), _.getShort)
- buffer.rewind()
- seq.foreach(STRING.append(_, buffer))
+ testNativeColumnType[LongType.type](LONG, _.putLong(_), _.getLong)
- buffer.rewind()
- seq.foreach { s =>
- val length = buffer.getInt
- assert(length === s.getBytes.length)
+ testNativeColumnType[ByteType.type](BYTE, _.put(_), _.get)
+
+ testNativeColumnType[DoubleType.type](DOUBLE, _.putDouble(_), _.getDouble)
+
+ testNativeColumnType[FloatType.type](FLOAT, _.putFloat(_), _.getFloat)
+ testNativeColumnType[StringType.type](
+ STRING,
+ (buffer: ByteBuffer, string: String) => {
+ val bytes = string.getBytes()
+ buffer.putInt(bytes.length).put(string.getBytes)
+ },
+ (buffer: ByteBuffer) => {
+ val length = buffer.getInt()
val bytes = new Array[Byte](length)
buffer.get(bytes, 0, length)
- assert(s === new String(bytes))
- }
- }
-
- test("BINARY") {
- val buffer = ByteBuffer.allocate(128)
- val seq = Array.fill(4) {
- val bytes = new Array[Byte](4)
- Random.nextBytes(bytes)
- bytes
- }
+ new String(bytes)
+ })
- seq.foreach { bytes =>
+ testColumnType[BinaryType.type, Array[Byte]](
+ BINARY,
+ (buffer: ByteBuffer, bytes: Array[Byte]) => {
buffer.putInt(bytes.length).put(bytes)
- }
-
- buffer.rewind()
- seq.foreach { b =>
- assert(b === BINARY.extract(buffer))
- }
-
- buffer.rewind()
- seq.foreach(BINARY.append(_, buffer))
-
- buffer.rewind()
- seq.foreach { b =>
- val length = buffer.getInt
- assert(length === b.length)
-
+ },
+ (buffer: ByteBuffer) => {
+ val length = buffer.getInt()
val bytes = new Array[Byte](length)
buffer.get(bytes, 0, length)
- assert(b === bytes)
- }
- }
+ bytes
+ })
test("GENERIC") {
val buffer = ByteBuffer.allocate(512)
@@ -177,43 +124,58 @@ class ColumnTypeSuite extends FunSuite {
val length = buffer.getInt()
assert(length === serializedObj.length)
- val bytes = new Array[Byte](length)
- buffer.get(bytes, 0, length)
- assert(obj === SparkSqlSerializer.deserialize(bytes))
+ expectResult(obj, "Deserialized object didn't equal to the original object") {
+ val bytes = new Array[Byte](length)
+ buffer.get(bytes, 0, length)
+ SparkSqlSerializer.deserialize(bytes)
+ }
buffer.rewind()
buffer.putInt(serializedObj.length).put(serializedObj)
- buffer.rewind()
- assert(obj === SparkSqlSerializer.deserialize(GENERIC.extract(buffer)))
+ expectResult(obj, "Deserialized object didn't equal to the original object") {
+ buffer.rewind()
+ SparkSqlSerializer.deserialize(GENERIC.extract(buffer))
+ }
+ }
+
+ def testNativeColumnType[T <: NativeType](
+ columnType: NativeColumnType[T],
+ putter: (ByteBuffer, T#JvmType) => Unit,
+ getter: (ByteBuffer) => T#JvmType) {
+
+ testColumnType[T, T#JvmType](columnType, putter, getter)
}
- def testNumericColumnType[T <: DataType, JvmType](
+ def testColumnType[T <: DataType, JvmType](
columnType: ColumnType[T, JvmType],
- seq: Seq[JvmType],
- buffer: ByteBuffer,
putter: (ByteBuffer, JvmType) => Unit,
getter: (ByteBuffer) => JvmType) {
- val columnTypeName = columnType.getClass.getSimpleName.stripSuffix("$")
+ val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE)
+ val seq = (0 until 4).map(_ => makeRandomValue(columnType))
- test(s"$columnTypeName.extract") {
+ test(s"$columnType.extract") {
buffer.rewind()
seq.foreach(putter(buffer, _))
buffer.rewind()
- seq.foreach { i =>
- assert(i === columnType.extract(buffer))
+ seq.foreach { expected =>
+ assert(
+ expected === columnType.extract(buffer),
+ "Extracted value didn't equal to the original one")
}
}
- test(s"$columnTypeName.append") {
+ test(s"$columnType.append") {
buffer.rewind()
seq.foreach(columnType.append(_, buffer))
buffer.rewind()
- seq.foreach { i =>
- assert(i === getter(buffer))
+ seq.foreach { expected =>
+ assert(
+ expected === getter(buffer),
+ "Extracted value didn't equal to the original one")
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarQuerySuite.scala
index 928851a385d41..70b2e851737f8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarQuerySuite.scala
@@ -17,11 +17,11 @@
package org.apache.spark.sql.columnar
+import org.apache.spark.sql.{QueryTest, TestData}
import org.apache.spark.sql.execution.SparkLogicalPlan
import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{TestData, DslQuerySuite}
-class ColumnarQuerySuite extends DslQuerySuite {
+class ColumnarQuerySuite extends QueryTest {
import TestData._
import TestSQLContext._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestData.scala
deleted file mode 100644
index ddcdede8d1a4a..0000000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestData.scala
+++ /dev/null
@@ -1,55 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.columnar
-
-import scala.util.Random
-
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-
-// TODO Enrich test data
-object ColumnarTestData {
- object GenericMutableRow {
- def apply(values: Any*) = {
- val row = new GenericMutableRow(values.length)
- row.indices.foreach { i =>
- row(i) = values(i)
- }
- row
- }
- }
-
- def randomBytes(length: Int) = {
- val bytes = new Array[Byte](length)
- Random.nextBytes(bytes)
- bytes
- }
-
- val nonNullRandomRow = GenericMutableRow(
- Random.nextInt(),
- Random.nextLong(),
- Random.nextFloat(),
- Random.nextDouble(),
- Random.nextBoolean(),
- Random.nextInt(Byte.MaxValue).asInstanceOf[Byte],
- Random.nextInt(Short.MaxValue).asInstanceOf[Short],
- Random.nextString(Random.nextInt(64)),
- randomBytes(Random.nextInt(64)),
- Map(Random.nextInt() -> Random.nextString(4)))
-
- val nullRow = GenericMutableRow(Seq.fill(10)(null): _*)
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
new file mode 100644
index 0000000000000..04bdc43d95328
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar
+
+import scala.collection.immutable.HashSet
+import scala.util.Random
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.types.{DataType, NativeType}
+
+object ColumnarTestUtils {
+ def makeNullRow(length: Int) = {
+ val row = new GenericMutableRow(length)
+ (0 until length).foreach(row.setNullAt)
+ row
+ }
+
+ def makeRandomValue[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]): JvmType = {
+ def randomBytes(length: Int) = {
+ val bytes = new Array[Byte](length)
+ Random.nextBytes(bytes)
+ bytes
+ }
+
+ (columnType match {
+ case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
+ case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
+ case INT => Random.nextInt()
+ case LONG => Random.nextLong()
+ case FLOAT => Random.nextFloat()
+ case DOUBLE => Random.nextDouble()
+ case STRING => Random.nextString(Random.nextInt(32))
+ case BOOLEAN => Random.nextBoolean()
+ case BINARY => randomBytes(Random.nextInt(32))
+ case _ =>
+ // Using a random one-element map instead of an arbitrary object
+ Map(Random.nextInt() -> Random.nextString(Random.nextInt(32)))
+ }).asInstanceOf[JvmType]
+ }
+
+ def makeRandomValues(
+ head: ColumnType[_ <: DataType, _],
+ tail: ColumnType[_ <: DataType, _]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail)
+
+ def makeRandomValues(columnTypes: Seq[ColumnType[_ <: DataType, _]]): Seq[Any] = {
+ columnTypes.map(makeRandomValue(_))
+ }
+
+ def makeUniqueRandomValues[T <: DataType, JvmType](
+ columnType: ColumnType[T, JvmType],
+ count: Int): Seq[JvmType] = {
+
+ Iterator.iterate(HashSet.empty[JvmType]) { set =>
+ set + Iterator.continually(makeRandomValue(columnType)).filterNot(set.contains).next()
+ }.drop(count).next().toSeq
+ }
+
+ def makeRandomRow(
+ head: ColumnType[_ <: DataType, _],
+ tail: ColumnType[_ <: DataType, _]*): Row = makeRandomRow(Seq(head) ++ tail)
+
+ def makeRandomRow(columnTypes: Seq[ColumnType[_ <: DataType, _]]): Row = {
+ val row = new GenericMutableRow(columnTypes.length)
+ makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) =>
+ row(index) = value
+ }
+ row
+ }
+
+ def makeUniqueValuesAndSingleValueRows[T <: NativeType](
+ columnType: NativeColumnType[T],
+ count: Int) = {
+
+ val values = makeUniqueRandomValues(columnType, count)
+ val rows = values.map { value =>
+ val row = new GenericMutableRow(1)
+ row(0) = value
+ row
+ }
+
+ (values, rows)
+ }
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
index d413d483f4e7e..4a21eb6201a69 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
@@ -17,12 +17,29 @@
package org.apache.spark.sql.columnar
+import java.nio.ByteBuffer
+
import org.scalatest.FunSuite
-import org.apache.spark.sql.catalyst.types.DataType
+
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.types.DataType
+
+class TestNullableColumnAccessor[T <: DataType, JvmType](
+ buffer: ByteBuffer,
+ columnType: ColumnType[T, JvmType])
+ extends BasicColumnAccessor(buffer, columnType)
+ with NullableColumnAccessor
+
+object TestNullableColumnAccessor {
+ def apply[T <: DataType, JvmType](buffer: ByteBuffer, columnType: ColumnType[T, JvmType]) = {
+ // Skips the column type ID
+ buffer.getInt()
+ new TestNullableColumnAccessor(buffer, columnType)
+ }
+}
class NullableColumnAccessorSuite extends FunSuite {
- import ColumnarTestData._
+ import ColumnarTestUtils._
Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC).foreach {
testNullableColumnAccessor(_)
@@ -30,30 +47,32 @@ class NullableColumnAccessorSuite extends FunSuite {
def testNullableColumnAccessor[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) {
val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+ val nullRow = makeNullRow(1)
- test(s"$typeName accessor: empty column") {
- val builder = ColumnBuilder(columnType.typeId, 4)
- val accessor = ColumnAccessor(builder.build())
+ test(s"Nullable $typeName column accessor: empty column") {
+ val builder = TestNullableColumnBuilder(columnType)
+ val accessor = TestNullableColumnAccessor(builder.build(), columnType)
assert(!accessor.hasNext)
}
- test(s"$typeName accessor: access null values") {
- val builder = ColumnBuilder(columnType.typeId, 4)
+ test(s"Nullable $typeName column accessor: access null values") {
+ val builder = TestNullableColumnBuilder(columnType)
+ val randomRow = makeRandomRow(columnType)
(0 until 4).foreach { _ =>
- builder.appendFrom(nonNullRandomRow, columnType.typeId)
- builder.appendFrom(nullRow, columnType.typeId)
+ builder.appendFrom(randomRow, 0)
+ builder.appendFrom(nullRow, 0)
}
- val accessor = ColumnAccessor(builder.build())
+ val accessor = TestNullableColumnAccessor(builder.build(), columnType)
val row = new GenericMutableRow(1)
(0 until 4).foreach { _ =>
accessor.extractTo(row, 0)
- assert(row(0) === nonNullRandomRow(columnType.typeId))
+ assert(row(0) === randomRow(0))
accessor.extractTo(row, 0)
- assert(row(0) === null)
+ assert(row.isNullAt(0))
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
index 5222a47e1ab87..d9d1e1bfddb75 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
@@ -19,63 +19,71 @@ package org.apache.spark.sql.columnar
import org.scalatest.FunSuite
-import org.apache.spark.sql.catalyst.types.DataType
+import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.execution.SparkSqlSerializer
+class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType])
+ extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType)
+ with NullableColumnBuilder
+
+object TestNullableColumnBuilder {
+ def apply[T <: DataType, JvmType](columnType: ColumnType[T, JvmType], initialSize: Int = 0) = {
+ val builder = new TestNullableColumnBuilder(columnType)
+ builder.initialize(initialSize)
+ builder
+ }
+}
+
class NullableColumnBuilderSuite extends FunSuite {
- import ColumnarTestData._
+ import ColumnarTestUtils._
Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC).foreach {
testNullableColumnBuilder(_)
}
def testNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) {
- val columnBuilder = ColumnBuilder(columnType.typeId)
val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
test(s"$typeName column builder: empty column") {
- columnBuilder.initialize(4)
-
+ val columnBuilder = TestNullableColumnBuilder(columnType)
val buffer = columnBuilder.build()
- // For column type ID
- assert(buffer.getInt() === columnType.typeId)
- // For null count
- assert(buffer.getInt === 0)
+ expectResult(columnType.typeId, "Wrong column type ID")(buffer.getInt())
+ expectResult(0, "Wrong null count")(buffer.getInt())
assert(!buffer.hasRemaining)
}
test(s"$typeName column builder: buffer size auto growth") {
- columnBuilder.initialize(4)
+ val columnBuilder = TestNullableColumnBuilder(columnType)
+ val randomRow = makeRandomRow(columnType)
- (0 until 4) foreach { _ =>
- columnBuilder.appendFrom(nonNullRandomRow, columnType.typeId)
+ (0 until 4).foreach { _ =>
+ columnBuilder.appendFrom(randomRow, 0)
}
val buffer = columnBuilder.build()
- // For column type ID
- assert(buffer.getInt() === columnType.typeId)
- // For null count
- assert(buffer.getInt() === 0)
+ expectResult(columnType.typeId, "Wrong column type ID")(buffer.getInt())
+ expectResult(0, "Wrong null count")(buffer.getInt())
}
test(s"$typeName column builder: null values") {
- columnBuilder.initialize(4)
+ val columnBuilder = TestNullableColumnBuilder(columnType)
+ val randomRow = makeRandomRow(columnType)
+ val nullRow = makeNullRow(1)
- (0 until 4) foreach { _ =>
- columnBuilder.appendFrom(nonNullRandomRow, columnType.typeId)
- columnBuilder.appendFrom(nullRow, columnType.typeId)
+ (0 until 4).foreach { _ =>
+ columnBuilder.appendFrom(randomRow, 0)
+ columnBuilder.appendFrom(nullRow, 0)
}
val buffer = columnBuilder.build()
- // For column type ID
- assert(buffer.getInt() === columnType.typeId)
- // For null count
- assert(buffer.getInt() === 4)
+ expectResult(columnType.typeId, "Wrong column type ID")(buffer.getInt())
+ expectResult(4, "Wrong null count")(buffer.getInt())
+
// For null positions
- (1 to 7 by 2).foreach(i => assert(buffer.getInt() === i))
+ (1 to 7 by 2).foreach(expectResult(_, "Wrong null position")(buffer.getInt()))
// For non-null values
(0 until 4).foreach { _ =>
@@ -84,7 +92,8 @@ class NullableColumnBuilderSuite extends FunSuite {
} else {
columnType.extract(buffer)
}
- assert(actual === nonNullRandomRow(columnType.typeId))
+
+ assert(actual === randomRow(0), "Extracted value didn't equal to the original one")
}
assert(!buffer.hasRemaining)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
new file mode 100644
index 0000000000000..184691ab5b46a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar.compression
+
+import java.nio.ByteBuffer
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.catalyst.types.NativeType
+import org.apache.spark.sql.columnar._
+import org.apache.spark.sql.columnar.ColumnarTestUtils._
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+
+class DictionaryEncodingSuite extends FunSuite {
+ testDictionaryEncoding(new IntColumnStats, INT)
+ testDictionaryEncoding(new LongColumnStats, LONG)
+ testDictionaryEncoding(new StringColumnStats, STRING)
+
+ def testDictionaryEncoding[T <: NativeType](
+ columnStats: NativeColumnStats[T],
+ columnType: NativeColumnType[T]) {
+
+ val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+
+ def buildDictionary(buffer: ByteBuffer) = {
+ (0 until buffer.getInt()).map(columnType.extract(buffer) -> _.toShort).toMap
+ }
+
+ test(s"$DictionaryEncoding with $typeName: simple case") {
+ // -------------
+ // Tests encoder
+ // -------------
+
+ val builder = TestCompressibleColumnBuilder(columnStats, columnType, DictionaryEncoding)
+ val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, 2)
+
+ builder.initialize(0)
+ builder.appendFrom(rows(0), 0)
+ builder.appendFrom(rows(1), 0)
+ builder.appendFrom(rows(0), 0)
+ builder.appendFrom(rows(1), 0)
+
+ val buffer = builder.build()
+ val headerSize = CompressionScheme.columnHeaderSize(buffer)
+ // 4 extra bytes for dictionary size
+ val dictionarySize = 4 + values.map(columnType.actualSize).sum
+ // 4 `Short`s, 2 bytes each
+ val compressedSize = dictionarySize + 2 * 4
+ // 4 extra bytes for compression scheme type ID
+ expectResult(headerSize + 4 + compressedSize, "Wrong buffer capacity")(buffer.capacity)
+
+ // Skips column header
+ buffer.position(headerSize)
+ expectResult(DictionaryEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt())
+
+ val dictionary = buildDictionary(buffer)
+ Array[Short](0, 1).foreach { i =>
+ expectResult(i, "Wrong dictionary entry")(dictionary(values(i)))
+ }
+
+ Array[Short](0, 1, 0, 1).foreach {
+ expectResult(_, "Wrong column element value")(buffer.getShort())
+ }
+
+ // -------------
+ // Tests decoder
+ // -------------
+
+ // Rewinds, skips column header and 4 more bytes for compression scheme ID
+ buffer.rewind().position(headerSize + 4)
+
+ val decoder = new DictionaryEncoding.Decoder[T](buffer, columnType)
+
+ Array[Short](0, 1, 0, 1).foreach { i =>
+ expectResult(values(i), "Wrong decoded value")(decoder.next())
+ }
+
+ assert(!decoder.hasNext)
+ }
+ }
+
+ test(s"$DictionaryEncoding: overflow") {
+ val builder = TestCompressibleColumnBuilder(new IntColumnStats, INT, DictionaryEncoding)
+ builder.initialize(0)
+
+ (0 to Short.MaxValue).foreach { n =>
+ val row = new GenericMutableRow(1)
+ row.setInt(0, n)
+ builder.appendFrom(row, 0)
+ }
+
+ withClue("Dictionary overflowed, encoding should fail") {
+ intercept[Throwable] {
+ builder.build()
+ }
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
new file mode 100644
index 0000000000000..2089ad120d4f2
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar.compression
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.catalyst.types.NativeType
+import org.apache.spark.sql.columnar._
+import org.apache.spark.sql.columnar.ColumnarTestUtils._
+
+class RunLengthEncodingSuite extends FunSuite {
+ testRunLengthEncoding(new BooleanColumnStats, BOOLEAN)
+ testRunLengthEncoding(new ByteColumnStats, BYTE)
+ testRunLengthEncoding(new ShortColumnStats, SHORT)
+ testRunLengthEncoding(new IntColumnStats, INT)
+ testRunLengthEncoding(new LongColumnStats, LONG)
+ testRunLengthEncoding(new StringColumnStats, STRING)
+
+ def testRunLengthEncoding[T <: NativeType](
+ columnStats: NativeColumnStats[T],
+ columnType: NativeColumnType[T]) {
+
+ val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+
+ test(s"$RunLengthEncoding with $typeName: simple case") {
+ // -------------
+ // Tests encoder
+ // -------------
+
+ val builder = TestCompressibleColumnBuilder(columnStats, columnType, RunLengthEncoding)
+ val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, 2)
+
+ builder.initialize(0)
+ builder.appendFrom(rows(0), 0)
+ builder.appendFrom(rows(0), 0)
+ builder.appendFrom(rows(1), 0)
+ builder.appendFrom(rows(1), 0)
+
+ val buffer = builder.build()
+ val headerSize = CompressionScheme.columnHeaderSize(buffer)
+ // 4 extra bytes each run for run length
+ val compressedSize = values.map(columnType.actualSize(_) + 4).sum
+ // 4 extra bytes for compression scheme type ID
+ expectResult(headerSize + 4 + compressedSize, "Wrong buffer capacity")(buffer.capacity)
+
+ // Skips column header
+ buffer.position(headerSize)
+ expectResult(RunLengthEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt())
+
+ Array(0, 1).foreach { i =>
+ expectResult(values(i), "Wrong column element value")(columnType.extract(buffer))
+ expectResult(2, "Wrong run length")(buffer.getInt())
+ }
+
+ // -------------
+ // Tests decoder
+ // -------------
+
+ // Rewinds, skips column header and 4 more bytes for compression scheme ID
+ buffer.rewind().position(headerSize + 4)
+
+ val decoder = new RunLengthEncoding.Decoder[T](buffer, columnType)
+
+ Array(0, 0, 1, 1).foreach { i =>
+ expectResult(values(i), "Wrong decoded value")(decoder.next())
+ }
+
+ assert(!decoder.hasNext)
+ }
+
+ test(s"$RunLengthEncoding with $typeName: run length == 1") {
+ // -------------
+ // Tests encoder
+ // -------------
+
+ val builder = TestCompressibleColumnBuilder(columnStats, columnType, RunLengthEncoding)
+ val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, 2)
+
+ builder.initialize(0)
+ builder.appendFrom(rows(0), 0)
+ builder.appendFrom(rows(1), 0)
+
+ val buffer = builder.build()
+ val headerSize = CompressionScheme.columnHeaderSize(buffer)
+ // 4 bytes each run for run length
+ val compressedSize = values.map(columnType.actualSize(_) + 4).sum
+ // 4 bytes for compression scheme type ID
+ expectResult(headerSize + 4 + compressedSize, "Wrong buffer capacity")(buffer.capacity)
+
+ // Skips column header
+ buffer.position(headerSize)
+ expectResult(RunLengthEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt())
+
+ Array(0, 1).foreach { i =>
+ expectResult(values(i), "Wrong column element value")(columnType.extract(buffer))
+ expectResult(1, "Wrong run length")(buffer.getInt())
+ }
+
+ // -------------
+ // Tests decoder
+ // -------------
+
+ // Rewinds, skips column header and 4 more bytes for compression scheme ID
+ buffer.rewind().position(headerSize + 4)
+
+ val decoder = new RunLengthEncoding.Decoder[T](buffer, columnType)
+
+ Array(0, 1).foreach { i =>
+ expectResult(values(i), "Wrong decoded value")(decoder.next())
+ }
+
+ assert(!decoder.hasNext)
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
new file mode 100644
index 0000000000000..e0ec812863dcf
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.columnar.compression
+
+import org.apache.spark.sql.catalyst.types.NativeType
+import org.apache.spark.sql.columnar._
+
+class TestCompressibleColumnBuilder[T <: NativeType](
+ override val columnStats: NativeColumnStats[T],
+ override val columnType: NativeColumnType[T],
+ override val schemes: Seq[CompressionScheme])
+ extends NativeColumnBuilder(columnStats, columnType)
+ with NullableColumnBuilder
+ with CompressibleColumnBuilder[T] {
+
+ override protected def isWorthCompressing(encoder: Encoder) = true
+}
+
+object TestCompressibleColumnBuilder {
+ def apply[T <: NativeType](
+ columnStats: NativeColumnStats[T],
+ columnType: NativeColumnType[T],
+ scheme: CompressionScheme) = {
+
+ new TestCompressibleColumnBuilder(columnStats, columnType, Seq(scheme))
+ }
+}
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala
index ca5c8b8eb63dc..e55648b8ed15a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala
@@ -39,9 +39,9 @@ case class ExampleTGF(input: Seq[Attribute] = Seq('name, 'age)) extends Generato
val Seq(nameAttr, ageAttr) = input
- override def apply(input: Row): TraversableOnce[Row] = {
- val name = nameAttr.apply(input)
- val age = ageAttr.apply(input).asInstanceOf[Int]
+ override def eval(input: Row): TraversableOnce[Row] = {
+ val name = nameAttr.eval(input)
+ val age = ageAttr.eval(input).asInstanceOf[Int]
Iterator(
new GenericRow(Array[Any](s"$name is $age years old")),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 71caa709afca6..fc68d6c5620d3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -19,28 +19,67 @@ package org.apache.spark.sql.parquet
import org.scalatest.{BeforeAndAfterAll, FunSuite}
-import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.{Path, FileSystem}
import org.apache.hadoop.mapreduce.Job
+
import parquet.hadoop.ParquetFileWriter
-import parquet.hadoop.util.ContextUtil
import parquet.schema.MessageTypeParser
+import parquet.hadoop.util.ContextUtil
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.catalyst.util.getTempFilePath
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row}
import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.util.Utils
+import org.apache.spark.sql.catalyst.types.{StringType, IntegerType, DataType}
+import org.apache.spark.sql.{parquet, SchemaRDD}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import scala.Tuple2
+
+// Implicits
+import org.apache.spark.sql.test.TestSQLContext._
+
+case class TestRDDEntry(key: Int, value: String)
class ParquetQuerySuite extends FunSuite with BeforeAndAfterAll {
+
+ var testRDD: SchemaRDD = null
+
override def beforeAll() {
ParquetTestData.writeFile()
+ testRDD = parquetFile(ParquetTestData.testDir.toString)
+ testRDD.registerAsTable("testsource")
}
override def afterAll() {
- ParquetTestData.testFile.delete()
+ Utils.deleteRecursively(ParquetTestData.testDir)
+ // here we should also unregister the table??
+ }
+
+ test("self-join parquet files") {
+ val x = ParquetTestData.testData.as('x)
+ val y = ParquetTestData.testData.as('y)
+ val query = x.join(y).where("x.myint".attr === "y.myint".attr)
+
+ // Check to make sure that the attributes from either side of the join have unique expression
+ // ids.
+ query.queryExecution.analyzed.output.filter(_.name == "myint") match {
+ case Seq(i1, i2) if(i1.exprId == i2.exprId) =>
+ fail(s"Duplicate expression IDs found in query plan: $query")
+ case Seq(_, _) => // All good
+ }
+
+ val result = query.collect()
+ assert(result.size === 9, "self-join result has incorrect size")
+ assert(result(0).size === 12, "result row has incorrect size")
+ result.zipWithIndex.foreach {
+ case (row, index) => row.zipWithIndex.foreach {
+ case (field, column) => assert(field != null, s"self-join contains null value in row $index field $column")
+ }
+ }
}
test("Import of simple Parquet file") {
- val result = getRDD(ParquetTestData.testData).collect()
+ val result = parquetFile(ParquetTestData.testDir.toString).collect()
assert(result.size === 15)
result.zipWithIndex.foreach {
case (row, index) => {
@@ -106,20 +145,82 @@ class ParquetQuerySuite extends FunSuite with BeforeAndAfterAll {
fs.delete(path, true)
}
+ test("Creating case class RDD table") {
+ TestSQLContext.sparkContext.parallelize((1 to 100))
+ .map(i => TestRDDEntry(i, s"val_$i"))
+ .registerAsTable("tmp")
+ val rdd = sql("SELECT * FROM tmp").collect().sortBy(_.getInt(0))
+ var counter = 1
+ rdd.foreach {
+ // '===' does not like string comparison?
+ row: Row => {
+ assert(row.getString(1).equals(s"val_$counter"), s"row $counter value ${row.getString(1)} does not match val_$counter")
+ counter = counter + 1
+ }
+ }
+ }
+
+ test("Saving case class RDD table to file and reading it back in") {
+ val file = getTempFilePath("parquet")
+ val path = file.toString
+ val rdd = TestSQLContext.sparkContext.parallelize((1 to 100))
+ .map(i => TestRDDEntry(i, s"val_$i"))
+ rdd.saveAsParquetFile(path)
+ val readFile = parquetFile(path)
+ readFile.registerAsTable("tmpx")
+ val rdd_copy = sql("SELECT * FROM tmpx").collect()
+ val rdd_orig = rdd.collect()
+ for(i <- 0 to 99) {
+ assert(rdd_copy(i).apply(0) === rdd_orig(i).key, s"key error in line $i")
+ assert(rdd_copy(i).apply(1) === rdd_orig(i).value, s"value in line $i")
+ }
+ Utils.deleteRecursively(file)
+ assert(true)
+ }
+
+ test("insert (overwrite) via Scala API (new SchemaRDD)") {
+ val dirname = Utils.createTempDir()
+ val source_rdd = TestSQLContext.sparkContext.parallelize((1 to 100))
+ .map(i => TestRDDEntry(i, s"val_$i"))
+ source_rdd.registerAsTable("source")
+ val dest_rdd = createParquetFile(dirname.toString, ("key", IntegerType), ("value", StringType))
+ dest_rdd.registerAsTable("dest")
+ sql("INSERT OVERWRITE INTO dest SELECT * FROM source").collect()
+ val rdd_copy1 = sql("SELECT * FROM dest").collect()
+ assert(rdd_copy1.size === 100)
+ assert(rdd_copy1(0).apply(0) === 1)
+ assert(rdd_copy1(0).apply(1) === "val_1")
+ sql("INSERT INTO dest SELECT * FROM source").collect()
+ val rdd_copy2 = sql("SELECT * FROM dest").collect()
+ assert(rdd_copy2.size === 200)
+ Utils.deleteRecursively(dirname)
+ }
+
+ test("insert (appending) to same table via Scala API") {
+ sql("INSERT INTO testsource SELECT * FROM testsource").collect()
+ val double_rdd = sql("SELECT * FROM testsource").collect()
+ assert(double_rdd != null)
+ assert(double_rdd.size === 30)
+ for(i <- (0 to 14)) {
+ assert(double_rdd(i) === double_rdd(i+15), s"error: lines $i and ${i+15} to not match")
+ }
+ // let's restore the original test data
+ Utils.deleteRecursively(ParquetTestData.testDir)
+ ParquetTestData.writeFile()
+ }
+
/**
- * Computes the given [[ParquetRelation]] and returns its RDD.
+ * Creates an empty SchemaRDD backed by a ParquetRelation.
*
- * @param parquetRelation The Parquet relation.
- * @return An RDD of Rows.
+ * TODO: since this is so experimental it is better to have it here and not
+ * in SQLContext. Also note that when creating new AttributeReferences
+ * one needs to take care not to create duplicate Attribute ID's.
*/
- private def getRDD(parquetRelation: ParquetRelation): RDD[Row] = {
- val scanner = new ParquetTableScan(
- parquetRelation.output,
- parquetRelation,
- None)(TestSQLContext.sparkContext)
- scanner
- .execute
- .map(_.copy())
+ private def createParquetFile(path: String, schema: (Tuple2[String, DataType])*): SchemaRDD = {
+ val attributes = schema.map(t => new AttributeReference(t._1, t._2)())
+ new SchemaRDD(
+ TestSQLContext,
+ parquet.ParquetRelation.createEmpty(path, attributes, sparkContext.hadoopConfiguration))
}
}
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
index 7b5ea98f27ff5..a662da76ce25a 100644
--- a/sql/hive/pom.xml
+++ b/sql/hive/pom.xml
@@ -30,6 +30,17 @@
jar
Spark Project Hive
http://spark.apache.org/
+
+
+ yarn-alpha
+
+
+ org.apache.avro
+ avro
+
+
+
+
@@ -52,6 +63,10 @@
hive-exec
${hive.version}
+
+ org.codehaus.jackson
+ jackson-mapper-asl
+
org.apache.hive
hive-serde
@@ -76,6 +91,30 @@
org.scalatest
scalatest-maven-plugin
+
+
+
+ org.apache.maven.plugins
+ maven-dependency-plugin
+ 2.4
+
+
+ copy-dependencies
+ package
+
+ copy-dependencies
+
+
+
+ ${basedir}/../../lib_managed/jars
+ false
+ false
+ true
+ org.datanucleus
+
+
+
+
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index fc5057b73fe24..353458432b210 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -67,10 +67,24 @@ class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) {
class HiveContext(sc: SparkContext) extends SQLContext(sc) {
self =>
- override def parseSql(sql: String): LogicalPlan = HiveQl.parseSql(sql)
- override def executePlan(plan: LogicalPlan): this.QueryExecution =
+ override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution =
new this.QueryExecution { val logical = plan }
+ /**
+ * Executes a query expressed in HiveQL using Spark, returning the result as a SchemaRDD.
+ */
+ def hiveql(hqlQuery: String): SchemaRDD = {
+ val result = new SchemaRDD(this, HiveQl.parseSql(hqlQuery))
+ // We force query optimization to happen right away instead of letting it happen lazily like
+ // when using the query DSL. This is so DDL commands behave as expected. This is only
+ // generates the RDD lineage for DML queries, but do not perform any execution.
+ result.queryExecution.toRdd
+ result
+ }
+
+ /** An alias for `hiveql`. */
+ def hql(hqlQuery: String): SchemaRDD = hiveql(hqlQuery)
+
// Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur.
@transient
protected val outputBuffer = new java.io.OutputStream {
@@ -108,7 +122,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
/* A catalyst metadata catalog that points to the Hive Metastore. */
@transient
- override lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog {
+ override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog {
override def lookupRelation(
databaseName: Option[String],
tableName: String,
@@ -120,7 +134,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
/* An analyzer that uses the Hive metastore. */
@transient
- override lazy val analyzer = new Analyzer(catalog, HiveFunctionRegistry, caseSensitive = false)
+ override protected[sql] lazy val analyzer =
+ new Analyzer(catalog, HiveFunctionRegistry, caseSensitive = false)
/**
* Runs the specified SQL query using Hive.
@@ -188,13 +203,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
val hiveContext = self
override val strategies: Seq[Strategy] = Seq(
- TopK,
+ TakeOrdered,
ParquetOperations,
HiveTableScans,
DataSinks,
Scripts,
PartialAggregation,
- SparkEquiInnerJoin,
+ HashJoin,
BasicOperators,
CartesianProduct,
BroadcastNestedLoopJoin
@@ -202,14 +217,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
}
@transient
- override val planner = hivePlanner
+ override protected[sql] val planner = hivePlanner
@transient
protected lazy val emptyResult =
sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1)
/** Extends QueryExecution with hive specific features. */
- abstract class QueryExecution extends super.QueryExecution {
+ protected[sql] abstract class QueryExecution extends super.QueryExecution {
// TODO: Create mixin for the analyzer instead of overriding things here.
override lazy val optimizedPlan =
optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed)))
@@ -282,5 +297,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
val asString = result.map(_.zip(types).map(toHiveString)).map(_.mkString("\t")).toSeq
asString
}
+
+ override def simpleString: String =
+ logical match {
+ case _: NativeCommand => ""
+ case _ => executedPlan.toString
+ }
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 4f8353666a12b..fc053c56c052d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -141,6 +141,15 @@ class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Logging {
*/
override def registerTable(
databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = ???
+
+ /**
+ * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore.
+ * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]].
+ */
+ override def unregisterTable(
+ databaseName: Option[String], tableName: String): Unit = ???
+
+ override def unregisterAllTables() = {}
}
object HiveMetastoreTypes extends RegexParsers {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index f4b61381f9a27..4dac25b3f60e4 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -300,14 +300,17 @@ object HiveQl {
}
protected def nodeToDataType(node: Node): DataType = node match {
- case Token("TOK_BIGINT", Nil) => IntegerType
+ case Token("TOK_DECIMAL", Nil) => DecimalType
+ case Token("TOK_BIGINT", Nil) => LongType
case Token("TOK_INT", Nil) => IntegerType
- case Token("TOK_TINYINT", Nil) => IntegerType
- case Token("TOK_SMALLINT", Nil) => IntegerType
+ case Token("TOK_TINYINT", Nil) => ByteType
+ case Token("TOK_SMALLINT", Nil) => ShortType
case Token("TOK_BOOLEAN", Nil) => BooleanType
case Token("TOK_STRING", Nil) => StringType
case Token("TOK_FLOAT", Nil) => FloatType
- case Token("TOK_DOUBLE", Nil) => FloatType
+ case Token("TOK_DOUBLE", Nil) => DoubleType
+ case Token("TOK_TIMESTAMP", Nil) => TimestampType
+ case Token("TOK_BINARY", Nil) => BinaryType
case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType))
case Token("TOK_STRUCT",
Token("TOK_TABCOLLIST", fields) :: Nil) =>
@@ -529,7 +532,7 @@ object HiveQl {
val withLimit =
limitClause.map(l => nodeToExpr(l.getChildren.head))
- .map(StopAfter(_, withSort))
+ .map(Limit(_, withSort))
.getOrElse(withSort)
// TOK_INSERT_INTO means to add files to the table.
@@ -602,7 +605,7 @@ object HiveQl {
case Token("TOK_TABLESPLITSAMPLE",
Token("TOK_ROWCOUNT", Nil) ::
Token(count, Nil) :: Nil) =>
- StopAfter(Literal(count.toInt), relation)
+ Limit(Literal(count.toInt), relation)
case Token("TOK_TABLESPLITSAMPLE",
Token("TOK_PERCENT", Nil) ::
Token(fraction, Nil) :: Nil) =>
@@ -662,7 +665,7 @@ object HiveQl {
// worth the number of hacks that will be required to implement it. Namely, we need to add
// some sort of mapped star expansion that would expand all child output row to be similarly
// named output expressions where some aggregate expression has been applied (i.e. First).
- ??? /// Aggregate(groups, Star(None, First(_)) :: Nil, joinedResult)
+ ??? // Aggregate(groups, Star(None, First(_)) :: Nil, joinedResult)
case Token(allJoinTokens(joinToken),
relation1 ::
@@ -829,6 +832,8 @@ object HiveQl {
Cast(nodeToExpr(arg), BooleanType)
case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) =>
Cast(nodeToExpr(arg), DecimalType)
+ case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), TimestampType)
/* Arithmetic */
case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child))
@@ -847,12 +852,9 @@ object HiveQl {
case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right))
case Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right))
case Token("<=", left :: right:: Nil) => LessThanOrEqual(nodeToExpr(left), nodeToExpr(right))
- case Token("LIKE", left :: right:: Nil) =>
- UnresolvedFunction("LIKE", Seq(nodeToExpr(left), nodeToExpr(right)))
- case Token("RLIKE", left :: right:: Nil) =>
- UnresolvedFunction("RLIKE", Seq(nodeToExpr(left), nodeToExpr(right)))
- case Token("REGEXP", left :: right:: Nil) =>
- UnresolvedFunction("REGEXP", Seq(nodeToExpr(left), nodeToExpr(right)))
+ case Token("LIKE", left :: right:: Nil) => Like(nodeToExpr(left), nodeToExpr(right))
+ case Token("RLIKE", left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right))
+ case Token("REGEXP", left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right))
case Token("TOK_FUNCTION", Token("TOK_ISNOTNULL", Nil) :: child :: Nil) =>
IsNotNull(nodeToExpr(child))
case Token("TOK_FUNCTION", Token("TOK_ISNULL", Nil) :: child :: Nil) =>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index ca5311344615f..0da5eb754cb3f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -94,7 +94,7 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon
val tablePath = hiveTable.getPath
val inputPathStr = applyFilterIfNeeded(tablePath, filterOpt)
- //logDebug("Table input: %s".format(tablePath))
+ // logDebug("Table input: %s".format(tablePath))
val ifc = hiveTable.getInputFormatClass
.asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]]
val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
index bc3447b9d802d..2fea9702954d7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
@@ -110,10 +110,10 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) {
val describedTable = "DESCRIBE (\\w+)".r
- class SqlQueryExecution(sql: String) extends this.QueryExecution {
- lazy val logical = HiveQl.parseSql(sql)
- def hiveExec() = runSqlHive(sql)
- override def toString = sql + "\n" + super.toString
+ protected[hive] class HiveQLQueryExecution(hql: String) extends this.QueryExecution {
+ lazy val logical = HiveQl.parseSql(hql)
+ def hiveExec() = runSqlHive(hql)
+ override def toString = hql + "\n" + super.toString
}
/**
@@ -140,8 +140,8 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) {
case class TestTable(name: String, commands: (()=>Unit)*)
- implicit class SqlCmd(sql: String) {
- def cmd = () => new SqlQueryExecution(sql).stringResult(): Unit
+ protected[hive] implicit class SqlCmd(sql: String) {
+ def cmd = () => new HiveQLQueryExecution(sql).stringResult(): Unit
}
/**
@@ -313,6 +313,8 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) {
catalog.client.dropDatabase(db, true, false, true)
}
+ catalog.unregisterAllTables()
+
FunctionRegistry.getFunctionNames.filterNot(originalUdfs.contains(_)).foreach { udfName =>
FunctionRegistry.unregisterTemporaryUDF(udfName)
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala
new file mode 100644
index 0000000000000..6df76fa825101
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.api.java
+
+import org.apache.spark.api.java.JavaSparkContext
+import org.apache.spark.sql.api.java.{JavaSQLContext, JavaSchemaRDD}
+import org.apache.spark.sql.hive.{HiveContext, HiveQl}
+
+/**
+ * The entry point for executing Spark SQL queries from a Java program.
+ */
+class JavaHiveContext(sparkContext: JavaSparkContext) extends JavaSQLContext(sparkContext) {
+
+ override val sqlContext = new HiveContext(sparkContext)
+
+ /**
+ * Executes a query expressed in HiveQL, returning the result as a JavaSchemaRDD.
+ */
+ def hql(hqlQuery: String): JavaSchemaRDD = {
+ val result = new JavaSchemaRDD(sqlContext, HiveQl.parseSql(hqlQuery))
+ // We force query optimization to happen right away instead of letting it happen lazily like
+ // when using the query DSL. This is so DDL commands behave as expected. This is only
+ // generates the RDD lineage for DML queries, but do not perform any execution.
+ result.queryExecution.toRdd
+ result
+ }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala
index e2d9d8de2572a..821fb22112f87 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala
@@ -106,7 +106,7 @@ case class HiveTableScan(
}
private def castFromString(value: String, dataType: DataType) = {
- Cast(Literal(value), dataType).apply(null)
+ Cast(Literal(value), dataType).eval(null)
}
@transient
@@ -134,7 +134,7 @@ case class HiveTableScan(
// Only partitioned values are needed here, since the predicate has already been bound to
// partition key attribute references.
val row = new GenericRow(castedValues.toArray)
- shouldKeep.apply(row).asInstanceOf[Boolean]
+ shouldKeep.eval(row).asInstanceOf[Boolean]
}
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index 44901db3f963b..f9b437d435eba 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -190,8 +190,8 @@ case class HiveSimpleUdf(name: String, children: Seq[Expression]) extends HiveUd
}
// TODO: Finish input output types.
- override def apply(input: Row): Any = {
- val evaluatedChildren = children.map(_.apply(input))
+ override def eval(input: Row): Any = {
+ val evaluatedChildren = children.map(_.eval(input))
// Wrap the function arguments in the expected types.
val args = evaluatedChildren.zip(wrappers).map {
case (arg, wrapper) => wrapper(arg)
@@ -216,12 +216,12 @@ case class HiveGenericUdf(
val dataType: DataType = inspectorToDataType(returnInspector)
- override def apply(input: Row): Any = {
+ override def eval(input: Row): Any = {
returnInspector // Make sure initialized.
val args = children.map { v =>
new DeferredObject {
override def prepare(i: Int) = {}
- override def get(): AnyRef = wrap(v.apply(input))
+ override def get(): AnyRef = wrap(v.eval(input))
}
}.toArray
unwrap(function.evaluate(args))
@@ -337,13 +337,16 @@ case class HiveGenericUdaf(
type UDFType = AbstractGenericUDAFResolver
+ @transient
protected lazy val resolver: AbstractGenericUDAFResolver = createFunction(name)
+ @transient
protected lazy val objectInspector = {
resolver.getEvaluator(children.map(_.dataType.toTypeInfo).toArray)
.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
}
+ @transient
protected lazy val inspectors = children.map(_.dataType).map(toInspector)
def dataType: DataType = inspectorToDataType(objectInspector)
@@ -403,7 +406,7 @@ case class HiveGenericUdtf(
}
}
- override def apply(input: Row): TraversableOnce[Row] = {
+ override def eval(input: Row): TraversableOnce[Row] = {
outputInspectors // Make sure initialized.
val inputProjection = new Projection(children)
@@ -457,7 +460,7 @@ case class HiveUdafFunction(
private val buffer =
function.getNewAggregationBuffer.asInstanceOf[GenericUDAFEvaluator.AbstractAggregationBuffer]
- override def apply(input: Row): Any = unwrapData(function.evaluate(buffer), returnInspector)
+ override def eval(input: Row): Any = unwrapData(function.evaluate(buffer), returnInspector)
@transient
val inputProjection = new Projection(exprs)
diff --git a/sql/hive/src/test/resources/golden/alias.*-0-7bdb861d11e895aaea545810cdac316d b/sql/hive/src/test/resources/golden/alias.*-0-7bdb861d11e895aaea545810cdac316d
deleted file mode 100644
index 5f4de85940513..0000000000000
--- a/sql/hive/src/test/resources/golden/alias.*-0-7bdb861d11e895aaea545810cdac316d
+++ /dev/null
@@ -1 +0,0 @@
-0 val_0
\ No newline at end of file
diff --git a/sql/hive/src/test/resources/golden/alias.star-0-7bdb861d11e895aaea545810cdac316d b/sql/hive/src/test/resources/golden/alias.star-0-7bdb861d11e895aaea545810cdac316d
new file mode 100644
index 0000000000000..016f64cc26f2a
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/alias.star-0-7bdb861d11e895aaea545810cdac316d
@@ -0,0 +1 @@
+0 val_0
diff --git a/sql/hive/src/test/resources/golden/insert1-0-7faa9807151781e4207103aa568e321c b/sql/hive/src/test/resources/golden/insert1-0-7faa9807151781e4207103aa568e321c
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert1-1-91d7b05c9024bff60b55f415cbeacc8b b/sql/hive/src/test/resources/golden/insert1-1-91d7b05c9024bff60b55f415cbeacc8b
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert1-10-64f83491a8fe675ef3a4a9a474ac0439 b/sql/hive/src/test/resources/golden/insert1-10-64f83491a8fe675ef3a4a9a474ac0439
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert1-11-6f2797b6f81943d3b53b8d247ae8512b b/sql/hive/src/test/resources/golden/insert1-11-6f2797b6f81943d3b53b8d247ae8512b
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert1-12-7a3c0a3f06484c912b9e951d8a2d8ac6 b/sql/hive/src/test/resources/golden/insert1-12-7a3c0a3f06484c912b9e951d8a2d8ac6
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert1-13-42b03f938894fdafc7fff640711a9b2f b/sql/hive/src/test/resources/golden/insert1-13-42b03f938894fdafc7fff640711a9b2f
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert1-14-e021dfb28597811870c03b3242972927 b/sql/hive/src/test/resources/golden/insert1-14-e021dfb28597811870c03b3242972927
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert1-15-c7fca497a4580b54a0a13b3b72da5d7c b/sql/hive/src/test/resources/golden/insert1-15-c7fca497a4580b54a0a13b3b72da5d7c
new file mode 100644
index 0000000000000..5be49cad9a8ba
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/insert1-15-c7fca497a4580b54a0a13b3b72da5d7c
@@ -0,0 +1,2 @@
+db2_insert1
+db2_insert2
diff --git a/sql/hive/src/test/resources/golden/insert1-16-7a9e67189d3d4151f23b12c22bde06b5 b/sql/hive/src/test/resources/golden/insert1-16-7a9e67189d3d4151f23b12c22bde06b5
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert1-17-5528e36b3b0f5b14313898cc45f9c23a b/sql/hive/src/test/resources/golden/insert1-17-5528e36b3b0f5b14313898cc45f9c23a
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert1-18-16d78fba2d86277bc2f804037cc0a8b4 b/sql/hive/src/test/resources/golden/insert1-18-16d78fba2d86277bc2f804037cc0a8b4
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert1-19-62518ff6810db9cdd8926702192a206b b/sql/hive/src/test/resources/golden/insert1-19-62518ff6810db9cdd8926702192a206b
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert1-2-3f1de4475930285c3fdbe3a5ccd4e868 b/sql/hive/src/test/resources/golden/insert1-2-3f1de4475930285c3fdbe3a5ccd4e868
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert1-20-f4dc51ad64bb8662d066a8b9003da3d4 b/sql/hive/src/test/resources/golden/insert1-20-f4dc51ad64bb8662d066a8b9003da3d4
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert1-21-bb7624250ab556f2d40bfb8d419be487 b/sql/hive/src/test/resources/golden/insert1-21-bb7624250ab556f2d40bfb8d419be487
new file mode 100644
index 0000000000000..1e3637ebc6af2
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/insert1-21-bb7624250ab556f2d40bfb8d419be487
@@ -0,0 +1,2 @@
+db1_insert1
+db1_insert2
diff --git a/sql/hive/src/test/resources/golden/insert1-3-89f8a028e32fae213b575b4df4e26e9c b/sql/hive/src/test/resources/golden/insert1-3-89f8a028e32fae213b575b4df4e26e9c
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert1-4-c7a68c0884785d0f5e62b287eb305d64 b/sql/hive/src/test/resources/golden/insert1-4-c7a68c0884785d0f5e62b287eb305d64
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert1-5-cb87ee12092fdf05daed82485c32a285 b/sql/hive/src/test/resources/golden/insert1-5-cb87ee12092fdf05daed82485c32a285
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert1-6-b97ba93a2c9ae671ecfc4fa95c024dda b/sql/hive/src/test/resources/golden/insert1-6-b97ba93a2c9ae671ecfc4fa95c024dda
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert1-7-a2cd0615b9e79befd9c1842516150a61 b/sql/hive/src/test/resources/golden/insert1-7-a2cd0615b9e79befd9c1842516150a61
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert1-8-5942e331621fe522fc297844046d2370 b/sql/hive/src/test/resources/golden/insert1-8-5942e331621fe522fc297844046d2370
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert1-9-5c5132707d7a4fb6e6a3de1a6719721a b/sql/hive/src/test/resources/golden/insert1-9-5c5132707d7a4fb6e6a3de1a6719721a
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-0-5528e36b3b0f5b14313898cc45f9c23a b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-0-5528e36b3b0f5b14313898cc45f9c23a
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-1-deb504f4f70fd7db975950c3c47959ee b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-1-deb504f4f70fd7db975950c3c47959ee
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-10-fda2e4be738186c0938f92d5072df55a b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-10-fda2e4be738186c0938f92d5072df55a
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-11-9fb177236623d1b62acff28507033436 b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-11-9fb177236623d1b62acff28507033436
new file mode 100644
index 0000000000000..01f2b7063f91b
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-11-9fb177236623d1b62acff28507033436
@@ -0,0 +1,5 @@
+98 val_98
+98 val_98
+98 val_98
+97 val_97
+97 val_97
diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-12-99d5ad32bb81640cb284312841b60000 b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-12-99d5ad32bb81640cb284312841b60000
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-13-9dda06e1aae1860bd19eee97703a8217 b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-13-9dda06e1aae1860bd19eee97703a8217
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-14-19daabdd4c0d403c8781967248d09c53 b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-14-19daabdd4c0d403c8781967248d09c53
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-15-812006e1f11e005e5029866d1cf004f6 b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-15-812006e1f11e005e5029866d1cf004f6
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-2-bd042746328158822a25d711ffed18dd b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-2-bd042746328158822a25d711ffed18dd
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-3-b7aaedd7d624af4e48637ff1acabe485 b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-3-b7aaedd7d624af4e48637ff1acabe485
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-4-dece2650bf0615e566cd6c84181ce026 b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-4-dece2650bf0615e566cd6c84181ce026
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-5-1eb5c694e5a02aa292e24a0849350108 b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-5-1eb5c694e5a02aa292e24a0849350108
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-6-ab49e0665a80a6b34dadc96f1d18ce26 b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-6-ab49e0665a80a6b34dadc96f1d18ce26
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-7-fda2e4be738186c0938f92d5072df55a b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-7-fda2e4be738186c0938f92d5072df55a
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-8-9fb177236623d1b62acff28507033436 b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-8-9fb177236623d1b62acff28507033436
new file mode 100644
index 0000000000000..01f2b7063f91b
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-8-9fb177236623d1b62acff28507033436
@@ -0,0 +1,5 @@
+98 val_98
+98 val_98
+98 val_98
+97 val_97
+97 val_97
diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-9-ab49e0665a80a6b34dadc96f1d18ce26 b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-9-ab49e0665a80a6b34dadc96f1d18ce26
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/load_binary_data-0-491edd0c42ceb79e799ba50555bc8c15 b/sql/hive/src/test/resources/golden/load_binary_data-0-491edd0c42ceb79e799ba50555bc8c15
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/load_binary_data-1-5d72f8449b69df3c08e3f444f09428bc b/sql/hive/src/test/resources/golden/load_binary_data-1-5d72f8449b69df3c08e3f444f09428bc
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/load_binary_data-2-242b1655c7e7325ee9f26552ea8fc25 b/sql/hive/src/test/resources/golden/load_binary_data-2-242b1655c7e7325ee9f26552ea8fc25
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/load_binary_data-3-2a72df8d3e398d0963ef91162ce7d268 b/sql/hive/src/test/resources/golden/load_binary_data-3-2a72df8d3e398d0963ef91162ce7d268
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/partitioned table scan-0-3e8898a13ccef627603f340d1f8bdd80 b/sql/hive/src/test/resources/golden/partitioned table scan-0-3e8898a13ccef627603f340d1f8bdd80
new file mode 100644
index 0000000000000..a3cb00feaca62
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/partitioned table scan-0-3e8898a13ccef627603f340d1f8bdd80
@@ -0,0 +1,2000 @@
+2008-04-08 11 238 val_238
+2008-04-08 11 86 val_86
+2008-04-08 11 311 val_311
+2008-04-08 11 27 val_27
+2008-04-08 11 165 val_165
+2008-04-08 11 409 val_409
+2008-04-08 11 255 val_255
+2008-04-08 11 278 val_278
+2008-04-08 11 98 val_98
+2008-04-08 11 484 val_484
+2008-04-08 11 265 val_265
+2008-04-08 11 193 val_193
+2008-04-08 11 401 val_401
+2008-04-08 11 150 val_150
+2008-04-08 11 273 val_273
+2008-04-08 11 224 val_224
+2008-04-08 11 369 val_369
+2008-04-08 11 66 val_66
+2008-04-08 11 128 val_128
+2008-04-08 11 213 val_213
+2008-04-08 11 146 val_146
+2008-04-08 11 406 val_406
+2008-04-08 11 429 val_429
+2008-04-08 11 374 val_374
+2008-04-08 11 152 val_152
+2008-04-08 11 469 val_469
+2008-04-08 11 145 val_145
+2008-04-08 11 495 val_495
+2008-04-08 11 37 val_37
+2008-04-08 11 327 val_327
+2008-04-08 11 281 val_281
+2008-04-08 11 277 val_277
+2008-04-08 11 209 val_209
+2008-04-08 11 15 val_15
+2008-04-08 11 82 val_82
+2008-04-08 11 403 val_403
+2008-04-08 11 166 val_166
+2008-04-08 11 417 val_417
+2008-04-08 11 430 val_430
+2008-04-08 11 252 val_252
+2008-04-08 11 292 val_292
+2008-04-08 11 219 val_219
+2008-04-08 11 287 val_287
+2008-04-08 11 153 val_153
+2008-04-08 11 193 val_193
+2008-04-08 11 338 val_338
+2008-04-08 11 446 val_446
+2008-04-08 11 459 val_459
+2008-04-08 11 394 val_394
+2008-04-08 11 237 val_237
+2008-04-08 11 482 val_482
+2008-04-08 11 174 val_174
+2008-04-08 11 413 val_413
+2008-04-08 11 494 val_494
+2008-04-08 11 207 val_207
+2008-04-08 11 199 val_199
+2008-04-08 11 466 val_466
+2008-04-08 11 208 val_208
+2008-04-08 11 174 val_174
+2008-04-08 11 399 val_399
+2008-04-08 11 396 val_396
+2008-04-08 11 247 val_247
+2008-04-08 11 417 val_417
+2008-04-08 11 489 val_489
+2008-04-08 11 162 val_162
+2008-04-08 11 377 val_377
+2008-04-08 11 397 val_397
+2008-04-08 11 309 val_309
+2008-04-08 11 365 val_365
+2008-04-08 11 266 val_266
+2008-04-08 11 439 val_439
+2008-04-08 11 342 val_342
+2008-04-08 11 367 val_367
+2008-04-08 11 325 val_325
+2008-04-08 11 167 val_167
+2008-04-08 11 195 val_195
+2008-04-08 11 475 val_475
+2008-04-08 11 17 val_17
+2008-04-08 11 113 val_113
+2008-04-08 11 155 val_155
+2008-04-08 11 203 val_203
+2008-04-08 11 339 val_339
+2008-04-08 11 0 val_0
+2008-04-08 11 455 val_455
+2008-04-08 11 128 val_128
+2008-04-08 11 311 val_311
+2008-04-08 11 316 val_316
+2008-04-08 11 57 val_57
+2008-04-08 11 302 val_302
+2008-04-08 11 205 val_205
+2008-04-08 11 149 val_149
+2008-04-08 11 438 val_438
+2008-04-08 11 345 val_345
+2008-04-08 11 129 val_129
+2008-04-08 11 170 val_170
+2008-04-08 11 20 val_20
+2008-04-08 11 489 val_489
+2008-04-08 11 157 val_157
+2008-04-08 11 378 val_378
+2008-04-08 11 221 val_221
+2008-04-08 11 92 val_92
+2008-04-08 11 111 val_111
+2008-04-08 11 47 val_47
+2008-04-08 11 72 val_72
+2008-04-08 11 4 val_4
+2008-04-08 11 280 val_280
+2008-04-08 11 35 val_35
+2008-04-08 11 427 val_427
+2008-04-08 11 277 val_277
+2008-04-08 11 208 val_208
+2008-04-08 11 356 val_356
+2008-04-08 11 399 val_399
+2008-04-08 11 169 val_169
+2008-04-08 11 382 val_382
+2008-04-08 11 498 val_498
+2008-04-08 11 125 val_125
+2008-04-08 11 386 val_386
+2008-04-08 11 437 val_437
+2008-04-08 11 469 val_469
+2008-04-08 11 192 val_192
+2008-04-08 11 286 val_286
+2008-04-08 11 187 val_187
+2008-04-08 11 176 val_176
+2008-04-08 11 54 val_54
+2008-04-08 11 459 val_459
+2008-04-08 11 51 val_51
+2008-04-08 11 138 val_138
+2008-04-08 11 103 val_103
+2008-04-08 11 239 val_239
+2008-04-08 11 213 val_213
+2008-04-08 11 216 val_216
+2008-04-08 11 430 val_430
+2008-04-08 11 278 val_278
+2008-04-08 11 176 val_176
+2008-04-08 11 289 val_289
+2008-04-08 11 221 val_221
+2008-04-08 11 65 val_65
+2008-04-08 11 318 val_318
+2008-04-08 11 332 val_332
+2008-04-08 11 311 val_311
+2008-04-08 11 275 val_275
+2008-04-08 11 137 val_137
+2008-04-08 11 241 val_241
+2008-04-08 11 83 val_83
+2008-04-08 11 333 val_333
+2008-04-08 11 180 val_180
+2008-04-08 11 284 val_284
+2008-04-08 11 12 val_12
+2008-04-08 11 230 val_230
+2008-04-08 11 181 val_181
+2008-04-08 11 67 val_67
+2008-04-08 11 260 val_260
+2008-04-08 11 404 val_404
+2008-04-08 11 384 val_384
+2008-04-08 11 489 val_489
+2008-04-08 11 353 val_353
+2008-04-08 11 373 val_373
+2008-04-08 11 272 val_272
+2008-04-08 11 138 val_138
+2008-04-08 11 217 val_217
+2008-04-08 11 84 val_84
+2008-04-08 11 348 val_348
+2008-04-08 11 466 val_466
+2008-04-08 11 58 val_58
+2008-04-08 11 8 val_8
+2008-04-08 11 411 val_411
+2008-04-08 11 230 val_230
+2008-04-08 11 208 val_208
+2008-04-08 11 348 val_348
+2008-04-08 11 24 val_24
+2008-04-08 11 463 val_463
+2008-04-08 11 431 val_431
+2008-04-08 11 179 val_179
+2008-04-08 11 172 val_172
+2008-04-08 11 42 val_42
+2008-04-08 11 129 val_129
+2008-04-08 11 158 val_158
+2008-04-08 11 119 val_119
+2008-04-08 11 496 val_496
+2008-04-08 11 0 val_0
+2008-04-08 11 322 val_322
+2008-04-08 11 197 val_197
+2008-04-08 11 468 val_468
+2008-04-08 11 393 val_393
+2008-04-08 11 454 val_454
+2008-04-08 11 100 val_100
+2008-04-08 11 298 val_298
+2008-04-08 11 199 val_199
+2008-04-08 11 191 val_191
+2008-04-08 11 418 val_418
+2008-04-08 11 96 val_96
+2008-04-08 11 26 val_26
+2008-04-08 11 165 val_165
+2008-04-08 11 327 val_327
+2008-04-08 11 230 val_230
+2008-04-08 11 205 val_205
+2008-04-08 11 120 val_120
+2008-04-08 11 131 val_131
+2008-04-08 11 51 val_51
+2008-04-08 11 404 val_404
+2008-04-08 11 43 val_43
+2008-04-08 11 436 val_436
+2008-04-08 11 156 val_156
+2008-04-08 11 469 val_469
+2008-04-08 11 468 val_468
+2008-04-08 11 308 val_308
+2008-04-08 11 95 val_95
+2008-04-08 11 196 val_196
+2008-04-08 11 288 val_288
+2008-04-08 11 481 val_481
+2008-04-08 11 457 val_457
+2008-04-08 11 98 val_98
+2008-04-08 11 282 val_282
+2008-04-08 11 197 val_197
+2008-04-08 11 187 val_187
+2008-04-08 11 318 val_318
+2008-04-08 11 318 val_318
+2008-04-08 11 409 val_409
+2008-04-08 11 470 val_470
+2008-04-08 11 137 val_137
+2008-04-08 11 369 val_369
+2008-04-08 11 316 val_316
+2008-04-08 11 169 val_169
+2008-04-08 11 413 val_413
+2008-04-08 11 85 val_85
+2008-04-08 11 77 val_77
+2008-04-08 11 0 val_0
+2008-04-08 11 490 val_490
+2008-04-08 11 87 val_87
+2008-04-08 11 364 val_364
+2008-04-08 11 179 val_179
+2008-04-08 11 118 val_118
+2008-04-08 11 134 val_134
+2008-04-08 11 395 val_395
+2008-04-08 11 282 val_282
+2008-04-08 11 138 val_138
+2008-04-08 11 238 val_238
+2008-04-08 11 419 val_419
+2008-04-08 11 15 val_15
+2008-04-08 11 118 val_118
+2008-04-08 11 72 val_72
+2008-04-08 11 90 val_90
+2008-04-08 11 307 val_307
+2008-04-08 11 19 val_19
+2008-04-08 11 435 val_435
+2008-04-08 11 10 val_10
+2008-04-08 11 277 val_277
+2008-04-08 11 273 val_273
+2008-04-08 11 306 val_306
+2008-04-08 11 224 val_224
+2008-04-08 11 309 val_309
+2008-04-08 11 389 val_389
+2008-04-08 11 327 val_327
+2008-04-08 11 242 val_242
+2008-04-08 11 369 val_369
+2008-04-08 11 392 val_392
+2008-04-08 11 272 val_272
+2008-04-08 11 331 val_331
+2008-04-08 11 401 val_401
+2008-04-08 11 242 val_242
+2008-04-08 11 452 val_452
+2008-04-08 11 177 val_177
+2008-04-08 11 226 val_226
+2008-04-08 11 5 val_5
+2008-04-08 11 497 val_497
+2008-04-08 11 402 val_402
+2008-04-08 11 396 val_396
+2008-04-08 11 317 val_317
+2008-04-08 11 395 val_395
+2008-04-08 11 58 val_58
+2008-04-08 11 35 val_35
+2008-04-08 11 336 val_336
+2008-04-08 11 95 val_95
+2008-04-08 11 11 val_11
+2008-04-08 11 168 val_168
+2008-04-08 11 34 val_34
+2008-04-08 11 229 val_229
+2008-04-08 11 233 val_233
+2008-04-08 11 143 val_143
+2008-04-08 11 472 val_472
+2008-04-08 11 322 val_322
+2008-04-08 11 498 val_498
+2008-04-08 11 160 val_160
+2008-04-08 11 195 val_195
+2008-04-08 11 42 val_42
+2008-04-08 11 321 val_321
+2008-04-08 11 430 val_430
+2008-04-08 11 119 val_119
+2008-04-08 11 489 val_489
+2008-04-08 11 458 val_458
+2008-04-08 11 78 val_78
+2008-04-08 11 76 val_76
+2008-04-08 11 41 val_41
+2008-04-08 11 223 val_223
+2008-04-08 11 492 val_492
+2008-04-08 11 149 val_149
+2008-04-08 11 449 val_449
+2008-04-08 11 218 val_218
+2008-04-08 11 228 val_228
+2008-04-08 11 138 val_138
+2008-04-08 11 453 val_453
+2008-04-08 11 30 val_30
+2008-04-08 11 209 val_209
+2008-04-08 11 64 val_64
+2008-04-08 11 468 val_468
+2008-04-08 11 76 val_76
+2008-04-08 11 74 val_74
+2008-04-08 11 342 val_342
+2008-04-08 11 69 val_69
+2008-04-08 11 230 val_230
+2008-04-08 11 33 val_33
+2008-04-08 11 368 val_368
+2008-04-08 11 103 val_103
+2008-04-08 11 296 val_296
+2008-04-08 11 113 val_113
+2008-04-08 11 216 val_216
+2008-04-08 11 367 val_367
+2008-04-08 11 344 val_344
+2008-04-08 11 167 val_167
+2008-04-08 11 274 val_274
+2008-04-08 11 219 val_219
+2008-04-08 11 239 val_239
+2008-04-08 11 485 val_485
+2008-04-08 11 116 val_116
+2008-04-08 11 223 val_223
+2008-04-08 11 256 val_256
+2008-04-08 11 263 val_263
+2008-04-08 11 70 val_70
+2008-04-08 11 487 val_487
+2008-04-08 11 480 val_480
+2008-04-08 11 401 val_401
+2008-04-08 11 288 val_288
+2008-04-08 11 191 val_191
+2008-04-08 11 5 val_5
+2008-04-08 11 244 val_244
+2008-04-08 11 438 val_438
+2008-04-08 11 128 val_128
+2008-04-08 11 467 val_467
+2008-04-08 11 432 val_432
+2008-04-08 11 202 val_202
+2008-04-08 11 316 val_316
+2008-04-08 11 229 val_229
+2008-04-08 11 469 val_469
+2008-04-08 11 463 val_463
+2008-04-08 11 280 val_280
+2008-04-08 11 2 val_2
+2008-04-08 11 35 val_35
+2008-04-08 11 283 val_283
+2008-04-08 11 331 val_331
+2008-04-08 11 235 val_235
+2008-04-08 11 80 val_80
+2008-04-08 11 44 val_44
+2008-04-08 11 193 val_193
+2008-04-08 11 321 val_321
+2008-04-08 11 335 val_335
+2008-04-08 11 104 val_104
+2008-04-08 11 466 val_466
+2008-04-08 11 366 val_366
+2008-04-08 11 175 val_175
+2008-04-08 11 403 val_403
+2008-04-08 11 483 val_483
+2008-04-08 11 53 val_53
+2008-04-08 11 105 val_105
+2008-04-08 11 257 val_257
+2008-04-08 11 406 val_406
+2008-04-08 11 409 val_409
+2008-04-08 11 190 val_190
+2008-04-08 11 406 val_406
+2008-04-08 11 401 val_401
+2008-04-08 11 114 val_114
+2008-04-08 11 258 val_258
+2008-04-08 11 90 val_90
+2008-04-08 11 203 val_203
+2008-04-08 11 262 val_262
+2008-04-08 11 348 val_348
+2008-04-08 11 424 val_424
+2008-04-08 11 12 val_12
+2008-04-08 11 396 val_396
+2008-04-08 11 201 val_201
+2008-04-08 11 217 val_217
+2008-04-08 11 164 val_164
+2008-04-08 11 431 val_431
+2008-04-08 11 454 val_454
+2008-04-08 11 478 val_478
+2008-04-08 11 298 val_298
+2008-04-08 11 125 val_125
+2008-04-08 11 431 val_431
+2008-04-08 11 164 val_164
+2008-04-08 11 424 val_424
+2008-04-08 11 187 val_187
+2008-04-08 11 382 val_382
+2008-04-08 11 5 val_5
+2008-04-08 11 70 val_70
+2008-04-08 11 397 val_397
+2008-04-08 11 480 val_480
+2008-04-08 11 291 val_291
+2008-04-08 11 24 val_24
+2008-04-08 11 351 val_351
+2008-04-08 11 255 val_255
+2008-04-08 11 104 val_104
+2008-04-08 11 70 val_70
+2008-04-08 11 163 val_163
+2008-04-08 11 438 val_438
+2008-04-08 11 119 val_119
+2008-04-08 11 414 val_414
+2008-04-08 11 200 val_200
+2008-04-08 11 491 val_491
+2008-04-08 11 237 val_237
+2008-04-08 11 439 val_439
+2008-04-08 11 360 val_360
+2008-04-08 11 248 val_248
+2008-04-08 11 479 val_479
+2008-04-08 11 305 val_305
+2008-04-08 11 417 val_417
+2008-04-08 11 199 val_199
+2008-04-08 11 444 val_444
+2008-04-08 11 120 val_120
+2008-04-08 11 429 val_429
+2008-04-08 11 169 val_169
+2008-04-08 11 443 val_443
+2008-04-08 11 323 val_323
+2008-04-08 11 325 val_325
+2008-04-08 11 277 val_277
+2008-04-08 11 230 val_230
+2008-04-08 11 478 val_478
+2008-04-08 11 178 val_178
+2008-04-08 11 468 val_468
+2008-04-08 11 310 val_310
+2008-04-08 11 317 val_317
+2008-04-08 11 333 val_333
+2008-04-08 11 493 val_493
+2008-04-08 11 460 val_460
+2008-04-08 11 207 val_207
+2008-04-08 11 249 val_249
+2008-04-08 11 265 val_265
+2008-04-08 11 480 val_480
+2008-04-08 11 83 val_83
+2008-04-08 11 136 val_136
+2008-04-08 11 353 val_353
+2008-04-08 11 172 val_172
+2008-04-08 11 214 val_214
+2008-04-08 11 462 val_462
+2008-04-08 11 233 val_233
+2008-04-08 11 406 val_406
+2008-04-08 11 133 val_133
+2008-04-08 11 175 val_175
+2008-04-08 11 189 val_189
+2008-04-08 11 454 val_454
+2008-04-08 11 375 val_375
+2008-04-08 11 401 val_401
+2008-04-08 11 421 val_421
+2008-04-08 11 407 val_407
+2008-04-08 11 384 val_384
+2008-04-08 11 256 val_256
+2008-04-08 11 26 val_26
+2008-04-08 11 134 val_134
+2008-04-08 11 67 val_67
+2008-04-08 11 384 val_384
+2008-04-08 11 379 val_379
+2008-04-08 11 18 val_18
+2008-04-08 11 462 val_462
+2008-04-08 11 492 val_492
+2008-04-08 11 100 val_100
+2008-04-08 11 298 val_298
+2008-04-08 11 9 val_9
+2008-04-08 11 341 val_341
+2008-04-08 11 498 val_498
+2008-04-08 11 146 val_146
+2008-04-08 11 458 val_458
+2008-04-08 11 362 val_362
+2008-04-08 11 186 val_186
+2008-04-08 11 285 val_285
+2008-04-08 11 348 val_348
+2008-04-08 11 167 val_167
+2008-04-08 11 18 val_18
+2008-04-08 11 273 val_273
+2008-04-08 11 183 val_183
+2008-04-08 11 281 val_281
+2008-04-08 11 344 val_344
+2008-04-08 11 97 val_97
+2008-04-08 11 469 val_469
+2008-04-08 11 315 val_315
+2008-04-08 11 84 val_84
+2008-04-08 11 28 val_28
+2008-04-08 11 37 val_37
+2008-04-08 11 448 val_448
+2008-04-08 11 152 val_152
+2008-04-08 11 348 val_348
+2008-04-08 11 307 val_307
+2008-04-08 11 194 val_194
+2008-04-08 11 414 val_414
+2008-04-08 11 477 val_477
+2008-04-08 11 222 val_222
+2008-04-08 11 126 val_126
+2008-04-08 11 90 val_90
+2008-04-08 11 169 val_169
+2008-04-08 11 403 val_403
+2008-04-08 11 400 val_400
+2008-04-08 11 200 val_200
+2008-04-08 11 97 val_97
+2008-04-08 12 238 val_238
+2008-04-08 12 86 val_86
+2008-04-08 12 311 val_311
+2008-04-08 12 27 val_27
+2008-04-08 12 165 val_165
+2008-04-08 12 409 val_409
+2008-04-08 12 255 val_255
+2008-04-08 12 278 val_278
+2008-04-08 12 98 val_98
+2008-04-08 12 484 val_484
+2008-04-08 12 265 val_265
+2008-04-08 12 193 val_193
+2008-04-08 12 401 val_401
+2008-04-08 12 150 val_150
+2008-04-08 12 273 val_273
+2008-04-08 12 224 val_224
+2008-04-08 12 369 val_369
+2008-04-08 12 66 val_66
+2008-04-08 12 128 val_128
+2008-04-08 12 213 val_213
+2008-04-08 12 146 val_146
+2008-04-08 12 406 val_406
+2008-04-08 12 429 val_429
+2008-04-08 12 374 val_374
+2008-04-08 12 152 val_152
+2008-04-08 12 469 val_469
+2008-04-08 12 145 val_145
+2008-04-08 12 495 val_495
+2008-04-08 12 37 val_37
+2008-04-08 12 327 val_327
+2008-04-08 12 281 val_281
+2008-04-08 12 277 val_277
+2008-04-08 12 209 val_209
+2008-04-08 12 15 val_15
+2008-04-08 12 82 val_82
+2008-04-08 12 403 val_403
+2008-04-08 12 166 val_166
+2008-04-08 12 417 val_417
+2008-04-08 12 430 val_430
+2008-04-08 12 252 val_252
+2008-04-08 12 292 val_292
+2008-04-08 12 219 val_219
+2008-04-08 12 287 val_287
+2008-04-08 12 153 val_153
+2008-04-08 12 193 val_193
+2008-04-08 12 338 val_338
+2008-04-08 12 446 val_446
+2008-04-08 12 459 val_459
+2008-04-08 12 394 val_394
+2008-04-08 12 237 val_237
+2008-04-08 12 482 val_482
+2008-04-08 12 174 val_174
+2008-04-08 12 413 val_413
+2008-04-08 12 494 val_494
+2008-04-08 12 207 val_207
+2008-04-08 12 199 val_199
+2008-04-08 12 466 val_466
+2008-04-08 12 208 val_208
+2008-04-08 12 174 val_174
+2008-04-08 12 399 val_399
+2008-04-08 12 396 val_396
+2008-04-08 12 247 val_247
+2008-04-08 12 417 val_417
+2008-04-08 12 489 val_489
+2008-04-08 12 162 val_162
+2008-04-08 12 377 val_377
+2008-04-08 12 397 val_397
+2008-04-08 12 309 val_309
+2008-04-08 12 365 val_365
+2008-04-08 12 266 val_266
+2008-04-08 12 439 val_439
+2008-04-08 12 342 val_342
+2008-04-08 12 367 val_367
+2008-04-08 12 325 val_325
+2008-04-08 12 167 val_167
+2008-04-08 12 195 val_195
+2008-04-08 12 475 val_475
+2008-04-08 12 17 val_17
+2008-04-08 12 113 val_113
+2008-04-08 12 155 val_155
+2008-04-08 12 203 val_203
+2008-04-08 12 339 val_339
+2008-04-08 12 0 val_0
+2008-04-08 12 455 val_455
+2008-04-08 12 128 val_128
+2008-04-08 12 311 val_311
+2008-04-08 12 316 val_316
+2008-04-08 12 57 val_57
+2008-04-08 12 302 val_302
+2008-04-08 12 205 val_205
+2008-04-08 12 149 val_149
+2008-04-08 12 438 val_438
+2008-04-08 12 345 val_345
+2008-04-08 12 129 val_129
+2008-04-08 12 170 val_170
+2008-04-08 12 20 val_20
+2008-04-08 12 489 val_489
+2008-04-08 12 157 val_157
+2008-04-08 12 378 val_378
+2008-04-08 12 221 val_221
+2008-04-08 12 92 val_92
+2008-04-08 12 111 val_111
+2008-04-08 12 47 val_47
+2008-04-08 12 72 val_72
+2008-04-08 12 4 val_4
+2008-04-08 12 280 val_280
+2008-04-08 12 35 val_35
+2008-04-08 12 427 val_427
+2008-04-08 12 277 val_277
+2008-04-08 12 208 val_208
+2008-04-08 12 356 val_356
+2008-04-08 12 399 val_399
+2008-04-08 12 169 val_169
+2008-04-08 12 382 val_382
+2008-04-08 12 498 val_498
+2008-04-08 12 125 val_125
+2008-04-08 12 386 val_386
+2008-04-08 12 437 val_437
+2008-04-08 12 469 val_469
+2008-04-08 12 192 val_192
+2008-04-08 12 286 val_286
+2008-04-08 12 187 val_187
+2008-04-08 12 176 val_176
+2008-04-08 12 54 val_54
+2008-04-08 12 459 val_459
+2008-04-08 12 51 val_51
+2008-04-08 12 138 val_138
+2008-04-08 12 103 val_103
+2008-04-08 12 239 val_239
+2008-04-08 12 213 val_213
+2008-04-08 12 216 val_216
+2008-04-08 12 430 val_430
+2008-04-08 12 278 val_278
+2008-04-08 12 176 val_176
+2008-04-08 12 289 val_289
+2008-04-08 12 221 val_221
+2008-04-08 12 65 val_65
+2008-04-08 12 318 val_318
+2008-04-08 12 332 val_332
+2008-04-08 12 311 val_311
+2008-04-08 12 275 val_275
+2008-04-08 12 137 val_137
+2008-04-08 12 241 val_241
+2008-04-08 12 83 val_83
+2008-04-08 12 333 val_333
+2008-04-08 12 180 val_180
+2008-04-08 12 284 val_284
+2008-04-08 12 12 val_12
+2008-04-08 12 230 val_230
+2008-04-08 12 181 val_181
+2008-04-08 12 67 val_67
+2008-04-08 12 260 val_260
+2008-04-08 12 404 val_404
+2008-04-08 12 384 val_384
+2008-04-08 12 489 val_489
+2008-04-08 12 353 val_353
+2008-04-08 12 373 val_373
+2008-04-08 12 272 val_272
+2008-04-08 12 138 val_138
+2008-04-08 12 217 val_217
+2008-04-08 12 84 val_84
+2008-04-08 12 348 val_348
+2008-04-08 12 466 val_466
+2008-04-08 12 58 val_58
+2008-04-08 12 8 val_8
+2008-04-08 12 411 val_411
+2008-04-08 12 230 val_230
+2008-04-08 12 208 val_208
+2008-04-08 12 348 val_348
+2008-04-08 12 24 val_24
+2008-04-08 12 463 val_463
+2008-04-08 12 431 val_431
+2008-04-08 12 179 val_179
+2008-04-08 12 172 val_172
+2008-04-08 12 42 val_42
+2008-04-08 12 129 val_129
+2008-04-08 12 158 val_158
+2008-04-08 12 119 val_119
+2008-04-08 12 496 val_496
+2008-04-08 12 0 val_0
+2008-04-08 12 322 val_322
+2008-04-08 12 197 val_197
+2008-04-08 12 468 val_468
+2008-04-08 12 393 val_393
+2008-04-08 12 454 val_454
+2008-04-08 12 100 val_100
+2008-04-08 12 298 val_298
+2008-04-08 12 199 val_199
+2008-04-08 12 191 val_191
+2008-04-08 12 418 val_418
+2008-04-08 12 96 val_96
+2008-04-08 12 26 val_26
+2008-04-08 12 165 val_165
+2008-04-08 12 327 val_327
+2008-04-08 12 230 val_230
+2008-04-08 12 205 val_205
+2008-04-08 12 120 val_120
+2008-04-08 12 131 val_131
+2008-04-08 12 51 val_51
+2008-04-08 12 404 val_404
+2008-04-08 12 43 val_43
+2008-04-08 12 436 val_436
+2008-04-08 12 156 val_156
+2008-04-08 12 469 val_469
+2008-04-08 12 468 val_468
+2008-04-08 12 308 val_308
+2008-04-08 12 95 val_95
+2008-04-08 12 196 val_196
+2008-04-08 12 288 val_288
+2008-04-08 12 481 val_481
+2008-04-08 12 457 val_457
+2008-04-08 12 98 val_98
+2008-04-08 12 282 val_282
+2008-04-08 12 197 val_197
+2008-04-08 12 187 val_187
+2008-04-08 12 318 val_318
+2008-04-08 12 318 val_318
+2008-04-08 12 409 val_409
+2008-04-08 12 470 val_470
+2008-04-08 12 137 val_137
+2008-04-08 12 369 val_369
+2008-04-08 12 316 val_316
+2008-04-08 12 169 val_169
+2008-04-08 12 413 val_413
+2008-04-08 12 85 val_85
+2008-04-08 12 77 val_77
+2008-04-08 12 0 val_0
+2008-04-08 12 490 val_490
+2008-04-08 12 87 val_87
+2008-04-08 12 364 val_364
+2008-04-08 12 179 val_179
+2008-04-08 12 118 val_118
+2008-04-08 12 134 val_134
+2008-04-08 12 395 val_395
+2008-04-08 12 282 val_282
+2008-04-08 12 138 val_138
+2008-04-08 12 238 val_238
+2008-04-08 12 419 val_419
+2008-04-08 12 15 val_15
+2008-04-08 12 118 val_118
+2008-04-08 12 72 val_72
+2008-04-08 12 90 val_90
+2008-04-08 12 307 val_307
+2008-04-08 12 19 val_19
+2008-04-08 12 435 val_435
+2008-04-08 12 10 val_10
+2008-04-08 12 277 val_277
+2008-04-08 12 273 val_273
+2008-04-08 12 306 val_306
+2008-04-08 12 224 val_224
+2008-04-08 12 309 val_309
+2008-04-08 12 389 val_389
+2008-04-08 12 327 val_327
+2008-04-08 12 242 val_242
+2008-04-08 12 369 val_369
+2008-04-08 12 392 val_392
+2008-04-08 12 272 val_272
+2008-04-08 12 331 val_331
+2008-04-08 12 401 val_401
+2008-04-08 12 242 val_242
+2008-04-08 12 452 val_452
+2008-04-08 12 177 val_177
+2008-04-08 12 226 val_226
+2008-04-08 12 5 val_5
+2008-04-08 12 497 val_497
+2008-04-08 12 402 val_402
+2008-04-08 12 396 val_396
+2008-04-08 12 317 val_317
+2008-04-08 12 395 val_395
+2008-04-08 12 58 val_58
+2008-04-08 12 35 val_35
+2008-04-08 12 336 val_336
+2008-04-08 12 95 val_95
+2008-04-08 12 11 val_11
+2008-04-08 12 168 val_168
+2008-04-08 12 34 val_34
+2008-04-08 12 229 val_229
+2008-04-08 12 233 val_233
+2008-04-08 12 143 val_143
+2008-04-08 12 472 val_472
+2008-04-08 12 322 val_322
+2008-04-08 12 498 val_498
+2008-04-08 12 160 val_160
+2008-04-08 12 195 val_195
+2008-04-08 12 42 val_42
+2008-04-08 12 321 val_321
+2008-04-08 12 430 val_430
+2008-04-08 12 119 val_119
+2008-04-08 12 489 val_489
+2008-04-08 12 458 val_458
+2008-04-08 12 78 val_78
+2008-04-08 12 76 val_76
+2008-04-08 12 41 val_41
+2008-04-08 12 223 val_223
+2008-04-08 12 492 val_492
+2008-04-08 12 149 val_149
+2008-04-08 12 449 val_449
+2008-04-08 12 218 val_218
+2008-04-08 12 228 val_228
+2008-04-08 12 138 val_138
+2008-04-08 12 453 val_453
+2008-04-08 12 30 val_30
+2008-04-08 12 209 val_209
+2008-04-08 12 64 val_64
+2008-04-08 12 468 val_468
+2008-04-08 12 76 val_76
+2008-04-08 12 74 val_74
+2008-04-08 12 342 val_342
+2008-04-08 12 69 val_69
+2008-04-08 12 230 val_230
+2008-04-08 12 33 val_33
+2008-04-08 12 368 val_368
+2008-04-08 12 103 val_103
+2008-04-08 12 296 val_296
+2008-04-08 12 113 val_113
+2008-04-08 12 216 val_216
+2008-04-08 12 367 val_367
+2008-04-08 12 344 val_344
+2008-04-08 12 167 val_167
+2008-04-08 12 274 val_274
+2008-04-08 12 219 val_219
+2008-04-08 12 239 val_239
+2008-04-08 12 485 val_485
+2008-04-08 12 116 val_116
+2008-04-08 12 223 val_223
+2008-04-08 12 256 val_256
+2008-04-08 12 263 val_263
+2008-04-08 12 70 val_70
+2008-04-08 12 487 val_487
+2008-04-08 12 480 val_480
+2008-04-08 12 401 val_401
+2008-04-08 12 288 val_288
+2008-04-08 12 191 val_191
+2008-04-08 12 5 val_5
+2008-04-08 12 244 val_244
+2008-04-08 12 438 val_438
+2008-04-08 12 128 val_128
+2008-04-08 12 467 val_467
+2008-04-08 12 432 val_432
+2008-04-08 12 202 val_202
+2008-04-08 12 316 val_316
+2008-04-08 12 229 val_229
+2008-04-08 12 469 val_469
+2008-04-08 12 463 val_463
+2008-04-08 12 280 val_280
+2008-04-08 12 2 val_2
+2008-04-08 12 35 val_35
+2008-04-08 12 283 val_283
+2008-04-08 12 331 val_331
+2008-04-08 12 235 val_235
+2008-04-08 12 80 val_80
+2008-04-08 12 44 val_44
+2008-04-08 12 193 val_193
+2008-04-08 12 321 val_321
+2008-04-08 12 335 val_335
+2008-04-08 12 104 val_104
+2008-04-08 12 466 val_466
+2008-04-08 12 366 val_366
+2008-04-08 12 175 val_175
+2008-04-08 12 403 val_403
+2008-04-08 12 483 val_483
+2008-04-08 12 53 val_53
+2008-04-08 12 105 val_105
+2008-04-08 12 257 val_257
+2008-04-08 12 406 val_406
+2008-04-08 12 409 val_409
+2008-04-08 12 190 val_190
+2008-04-08 12 406 val_406
+2008-04-08 12 401 val_401
+2008-04-08 12 114 val_114
+2008-04-08 12 258 val_258
+2008-04-08 12 90 val_90
+2008-04-08 12 203 val_203
+2008-04-08 12 262 val_262
+2008-04-08 12 348 val_348
+2008-04-08 12 424 val_424
+2008-04-08 12 12 val_12
+2008-04-08 12 396 val_396
+2008-04-08 12 201 val_201
+2008-04-08 12 217 val_217
+2008-04-08 12 164 val_164
+2008-04-08 12 431 val_431
+2008-04-08 12 454 val_454
+2008-04-08 12 478 val_478
+2008-04-08 12 298 val_298
+2008-04-08 12 125 val_125
+2008-04-08 12 431 val_431
+2008-04-08 12 164 val_164
+2008-04-08 12 424 val_424
+2008-04-08 12 187 val_187
+2008-04-08 12 382 val_382
+2008-04-08 12 5 val_5
+2008-04-08 12 70 val_70
+2008-04-08 12 397 val_397
+2008-04-08 12 480 val_480
+2008-04-08 12 291 val_291
+2008-04-08 12 24 val_24
+2008-04-08 12 351 val_351
+2008-04-08 12 255 val_255
+2008-04-08 12 104 val_104
+2008-04-08 12 70 val_70
+2008-04-08 12 163 val_163
+2008-04-08 12 438 val_438
+2008-04-08 12 119 val_119
+2008-04-08 12 414 val_414
+2008-04-08 12 200 val_200
+2008-04-08 12 491 val_491
+2008-04-08 12 237 val_237
+2008-04-08 12 439 val_439
+2008-04-08 12 360 val_360
+2008-04-08 12 248 val_248
+2008-04-08 12 479 val_479
+2008-04-08 12 305 val_305
+2008-04-08 12 417 val_417
+2008-04-08 12 199 val_199
+2008-04-08 12 444 val_444
+2008-04-08 12 120 val_120
+2008-04-08 12 429 val_429
+2008-04-08 12 169 val_169
+2008-04-08 12 443 val_443
+2008-04-08 12 323 val_323
+2008-04-08 12 325 val_325
+2008-04-08 12 277 val_277
+2008-04-08 12 230 val_230
+2008-04-08 12 478 val_478
+2008-04-08 12 178 val_178
+2008-04-08 12 468 val_468
+2008-04-08 12 310 val_310
+2008-04-08 12 317 val_317
+2008-04-08 12 333 val_333
+2008-04-08 12 493 val_493
+2008-04-08 12 460 val_460
+2008-04-08 12 207 val_207
+2008-04-08 12 249 val_249
+2008-04-08 12 265 val_265
+2008-04-08 12 480 val_480
+2008-04-08 12 83 val_83
+2008-04-08 12 136 val_136
+2008-04-08 12 353 val_353
+2008-04-08 12 172 val_172
+2008-04-08 12 214 val_214
+2008-04-08 12 462 val_462
+2008-04-08 12 233 val_233
+2008-04-08 12 406 val_406
+2008-04-08 12 133 val_133
+2008-04-08 12 175 val_175
+2008-04-08 12 189 val_189
+2008-04-08 12 454 val_454
+2008-04-08 12 375 val_375
+2008-04-08 12 401 val_401
+2008-04-08 12 421 val_421
+2008-04-08 12 407 val_407
+2008-04-08 12 384 val_384
+2008-04-08 12 256 val_256
+2008-04-08 12 26 val_26
+2008-04-08 12 134 val_134
+2008-04-08 12 67 val_67
+2008-04-08 12 384 val_384
+2008-04-08 12 379 val_379
+2008-04-08 12 18 val_18
+2008-04-08 12 462 val_462
+2008-04-08 12 492 val_492
+2008-04-08 12 100 val_100
+2008-04-08 12 298 val_298
+2008-04-08 12 9 val_9
+2008-04-08 12 341 val_341
+2008-04-08 12 498 val_498
+2008-04-08 12 146 val_146
+2008-04-08 12 458 val_458
+2008-04-08 12 362 val_362
+2008-04-08 12 186 val_186
+2008-04-08 12 285 val_285
+2008-04-08 12 348 val_348
+2008-04-08 12 167 val_167
+2008-04-08 12 18 val_18
+2008-04-08 12 273 val_273
+2008-04-08 12 183 val_183
+2008-04-08 12 281 val_281
+2008-04-08 12 344 val_344
+2008-04-08 12 97 val_97
+2008-04-08 12 469 val_469
+2008-04-08 12 315 val_315
+2008-04-08 12 84 val_84
+2008-04-08 12 28 val_28
+2008-04-08 12 37 val_37
+2008-04-08 12 448 val_448
+2008-04-08 12 152 val_152
+2008-04-08 12 348 val_348
+2008-04-08 12 307 val_307
+2008-04-08 12 194 val_194
+2008-04-08 12 414 val_414
+2008-04-08 12 477 val_477
+2008-04-08 12 222 val_222
+2008-04-08 12 126 val_126
+2008-04-08 12 90 val_90
+2008-04-08 12 169 val_169
+2008-04-08 12 403 val_403
+2008-04-08 12 400 val_400
+2008-04-08 12 200 val_200
+2008-04-08 12 97 val_97
+2008-04-09 11 238 val_238
+2008-04-09 11 86 val_86
+2008-04-09 11 311 val_311
+2008-04-09 11 27 val_27
+2008-04-09 11 165 val_165
+2008-04-09 11 409 val_409
+2008-04-09 11 255 val_255
+2008-04-09 11 278 val_278
+2008-04-09 11 98 val_98
+2008-04-09 11 484 val_484
+2008-04-09 11 265 val_265
+2008-04-09 11 193 val_193
+2008-04-09 11 401 val_401
+2008-04-09 11 150 val_150
+2008-04-09 11 273 val_273
+2008-04-09 11 224 val_224
+2008-04-09 11 369 val_369
+2008-04-09 11 66 val_66
+2008-04-09 11 128 val_128
+2008-04-09 11 213 val_213
+2008-04-09 11 146 val_146
+2008-04-09 11 406 val_406
+2008-04-09 11 429 val_429
+2008-04-09 11 374 val_374
+2008-04-09 11 152 val_152
+2008-04-09 11 469 val_469
+2008-04-09 11 145 val_145
+2008-04-09 11 495 val_495
+2008-04-09 11 37 val_37
+2008-04-09 11 327 val_327
+2008-04-09 11 281 val_281
+2008-04-09 11 277 val_277
+2008-04-09 11 209 val_209
+2008-04-09 11 15 val_15
+2008-04-09 11 82 val_82
+2008-04-09 11 403 val_403
+2008-04-09 11 166 val_166
+2008-04-09 11 417 val_417
+2008-04-09 11 430 val_430
+2008-04-09 11 252 val_252
+2008-04-09 11 292 val_292
+2008-04-09 11 219 val_219
+2008-04-09 11 287 val_287
+2008-04-09 11 153 val_153
+2008-04-09 11 193 val_193
+2008-04-09 11 338 val_338
+2008-04-09 11 446 val_446
+2008-04-09 11 459 val_459
+2008-04-09 11 394 val_394
+2008-04-09 11 237 val_237
+2008-04-09 11 482 val_482
+2008-04-09 11 174 val_174
+2008-04-09 11 413 val_413
+2008-04-09 11 494 val_494
+2008-04-09 11 207 val_207
+2008-04-09 11 199 val_199
+2008-04-09 11 466 val_466
+2008-04-09 11 208 val_208
+2008-04-09 11 174 val_174
+2008-04-09 11 399 val_399
+2008-04-09 11 396 val_396
+2008-04-09 11 247 val_247
+2008-04-09 11 417 val_417
+2008-04-09 11 489 val_489
+2008-04-09 11 162 val_162
+2008-04-09 11 377 val_377
+2008-04-09 11 397 val_397
+2008-04-09 11 309 val_309
+2008-04-09 11 365 val_365
+2008-04-09 11 266 val_266
+2008-04-09 11 439 val_439
+2008-04-09 11 342 val_342
+2008-04-09 11 367 val_367
+2008-04-09 11 325 val_325
+2008-04-09 11 167 val_167
+2008-04-09 11 195 val_195
+2008-04-09 11 475 val_475
+2008-04-09 11 17 val_17
+2008-04-09 11 113 val_113
+2008-04-09 11 155 val_155
+2008-04-09 11 203 val_203
+2008-04-09 11 339 val_339
+2008-04-09 11 0 val_0
+2008-04-09 11 455 val_455
+2008-04-09 11 128 val_128
+2008-04-09 11 311 val_311
+2008-04-09 11 316 val_316
+2008-04-09 11 57 val_57
+2008-04-09 11 302 val_302
+2008-04-09 11 205 val_205
+2008-04-09 11 149 val_149
+2008-04-09 11 438 val_438
+2008-04-09 11 345 val_345
+2008-04-09 11 129 val_129
+2008-04-09 11 170 val_170
+2008-04-09 11 20 val_20
+2008-04-09 11 489 val_489
+2008-04-09 11 157 val_157
+2008-04-09 11 378 val_378
+2008-04-09 11 221 val_221
+2008-04-09 11 92 val_92
+2008-04-09 11 111 val_111
+2008-04-09 11 47 val_47
+2008-04-09 11 72 val_72
+2008-04-09 11 4 val_4
+2008-04-09 11 280 val_280
+2008-04-09 11 35 val_35
+2008-04-09 11 427 val_427
+2008-04-09 11 277 val_277
+2008-04-09 11 208 val_208
+2008-04-09 11 356 val_356
+2008-04-09 11 399 val_399
+2008-04-09 11 169 val_169
+2008-04-09 11 382 val_382
+2008-04-09 11 498 val_498
+2008-04-09 11 125 val_125
+2008-04-09 11 386 val_386
+2008-04-09 11 437 val_437
+2008-04-09 11 469 val_469
+2008-04-09 11 192 val_192
+2008-04-09 11 286 val_286
+2008-04-09 11 187 val_187
+2008-04-09 11 176 val_176
+2008-04-09 11 54 val_54
+2008-04-09 11 459 val_459
+2008-04-09 11 51 val_51
+2008-04-09 11 138 val_138
+2008-04-09 11 103 val_103
+2008-04-09 11 239 val_239
+2008-04-09 11 213 val_213
+2008-04-09 11 216 val_216
+2008-04-09 11 430 val_430
+2008-04-09 11 278 val_278
+2008-04-09 11 176 val_176
+2008-04-09 11 289 val_289
+2008-04-09 11 221 val_221
+2008-04-09 11 65 val_65
+2008-04-09 11 318 val_318
+2008-04-09 11 332 val_332
+2008-04-09 11 311 val_311
+2008-04-09 11 275 val_275
+2008-04-09 11 137 val_137
+2008-04-09 11 241 val_241
+2008-04-09 11 83 val_83
+2008-04-09 11 333 val_333
+2008-04-09 11 180 val_180
+2008-04-09 11 284 val_284
+2008-04-09 11 12 val_12
+2008-04-09 11 230 val_230
+2008-04-09 11 181 val_181
+2008-04-09 11 67 val_67
+2008-04-09 11 260 val_260
+2008-04-09 11 404 val_404
+2008-04-09 11 384 val_384
+2008-04-09 11 489 val_489
+2008-04-09 11 353 val_353
+2008-04-09 11 373 val_373
+2008-04-09 11 272 val_272
+2008-04-09 11 138 val_138
+2008-04-09 11 217 val_217
+2008-04-09 11 84 val_84
+2008-04-09 11 348 val_348
+2008-04-09 11 466 val_466
+2008-04-09 11 58 val_58
+2008-04-09 11 8 val_8
+2008-04-09 11 411 val_411
+2008-04-09 11 230 val_230
+2008-04-09 11 208 val_208
+2008-04-09 11 348 val_348
+2008-04-09 11 24 val_24
+2008-04-09 11 463 val_463
+2008-04-09 11 431 val_431
+2008-04-09 11 179 val_179
+2008-04-09 11 172 val_172
+2008-04-09 11 42 val_42
+2008-04-09 11 129 val_129
+2008-04-09 11 158 val_158
+2008-04-09 11 119 val_119
+2008-04-09 11 496 val_496
+2008-04-09 11 0 val_0
+2008-04-09 11 322 val_322
+2008-04-09 11 197 val_197
+2008-04-09 11 468 val_468
+2008-04-09 11 393 val_393
+2008-04-09 11 454 val_454
+2008-04-09 11 100 val_100
+2008-04-09 11 298 val_298
+2008-04-09 11 199 val_199
+2008-04-09 11 191 val_191
+2008-04-09 11 418 val_418
+2008-04-09 11 96 val_96
+2008-04-09 11 26 val_26
+2008-04-09 11 165 val_165
+2008-04-09 11 327 val_327
+2008-04-09 11 230 val_230
+2008-04-09 11 205 val_205
+2008-04-09 11 120 val_120
+2008-04-09 11 131 val_131
+2008-04-09 11 51 val_51
+2008-04-09 11 404 val_404
+2008-04-09 11 43 val_43
+2008-04-09 11 436 val_436
+2008-04-09 11 156 val_156
+2008-04-09 11 469 val_469
+2008-04-09 11 468 val_468
+2008-04-09 11 308 val_308
+2008-04-09 11 95 val_95
+2008-04-09 11 196 val_196
+2008-04-09 11 288 val_288
+2008-04-09 11 481 val_481
+2008-04-09 11 457 val_457
+2008-04-09 11 98 val_98
+2008-04-09 11 282 val_282
+2008-04-09 11 197 val_197
+2008-04-09 11 187 val_187
+2008-04-09 11 318 val_318
+2008-04-09 11 318 val_318
+2008-04-09 11 409 val_409
+2008-04-09 11 470 val_470
+2008-04-09 11 137 val_137
+2008-04-09 11 369 val_369
+2008-04-09 11 316 val_316
+2008-04-09 11 169 val_169
+2008-04-09 11 413 val_413
+2008-04-09 11 85 val_85
+2008-04-09 11 77 val_77
+2008-04-09 11 0 val_0
+2008-04-09 11 490 val_490
+2008-04-09 11 87 val_87
+2008-04-09 11 364 val_364
+2008-04-09 11 179 val_179
+2008-04-09 11 118 val_118
+2008-04-09 11 134 val_134
+2008-04-09 11 395 val_395
+2008-04-09 11 282 val_282
+2008-04-09 11 138 val_138
+2008-04-09 11 238 val_238
+2008-04-09 11 419 val_419
+2008-04-09 11 15 val_15
+2008-04-09 11 118 val_118
+2008-04-09 11 72 val_72
+2008-04-09 11 90 val_90
+2008-04-09 11 307 val_307
+2008-04-09 11 19 val_19
+2008-04-09 11 435 val_435
+2008-04-09 11 10 val_10
+2008-04-09 11 277 val_277
+2008-04-09 11 273 val_273
+2008-04-09 11 306 val_306
+2008-04-09 11 224 val_224
+2008-04-09 11 309 val_309
+2008-04-09 11 389 val_389
+2008-04-09 11 327 val_327
+2008-04-09 11 242 val_242
+2008-04-09 11 369 val_369
+2008-04-09 11 392 val_392
+2008-04-09 11 272 val_272
+2008-04-09 11 331 val_331
+2008-04-09 11 401 val_401
+2008-04-09 11 242 val_242
+2008-04-09 11 452 val_452
+2008-04-09 11 177 val_177
+2008-04-09 11 226 val_226
+2008-04-09 11 5 val_5
+2008-04-09 11 497 val_497
+2008-04-09 11 402 val_402
+2008-04-09 11 396 val_396
+2008-04-09 11 317 val_317
+2008-04-09 11 395 val_395
+2008-04-09 11 58 val_58
+2008-04-09 11 35 val_35
+2008-04-09 11 336 val_336
+2008-04-09 11 95 val_95
+2008-04-09 11 11 val_11
+2008-04-09 11 168 val_168
+2008-04-09 11 34 val_34
+2008-04-09 11 229 val_229
+2008-04-09 11 233 val_233
+2008-04-09 11 143 val_143
+2008-04-09 11 472 val_472
+2008-04-09 11 322 val_322
+2008-04-09 11 498 val_498
+2008-04-09 11 160 val_160
+2008-04-09 11 195 val_195
+2008-04-09 11 42 val_42
+2008-04-09 11 321 val_321
+2008-04-09 11 430 val_430
+2008-04-09 11 119 val_119
+2008-04-09 11 489 val_489
+2008-04-09 11 458 val_458
+2008-04-09 11 78 val_78
+2008-04-09 11 76 val_76
+2008-04-09 11 41 val_41
+2008-04-09 11 223 val_223
+2008-04-09 11 492 val_492
+2008-04-09 11 149 val_149
+2008-04-09 11 449 val_449
+2008-04-09 11 218 val_218
+2008-04-09 11 228 val_228
+2008-04-09 11 138 val_138
+2008-04-09 11 453 val_453
+2008-04-09 11 30 val_30
+2008-04-09 11 209 val_209
+2008-04-09 11 64 val_64
+2008-04-09 11 468 val_468
+2008-04-09 11 76 val_76
+2008-04-09 11 74 val_74
+2008-04-09 11 342 val_342
+2008-04-09 11 69 val_69
+2008-04-09 11 230 val_230
+2008-04-09 11 33 val_33
+2008-04-09 11 368 val_368
+2008-04-09 11 103 val_103
+2008-04-09 11 296 val_296
+2008-04-09 11 113 val_113
+2008-04-09 11 216 val_216
+2008-04-09 11 367 val_367
+2008-04-09 11 344 val_344
+2008-04-09 11 167 val_167
+2008-04-09 11 274 val_274
+2008-04-09 11 219 val_219
+2008-04-09 11 239 val_239
+2008-04-09 11 485 val_485
+2008-04-09 11 116 val_116
+2008-04-09 11 223 val_223
+2008-04-09 11 256 val_256
+2008-04-09 11 263 val_263
+2008-04-09 11 70 val_70
+2008-04-09 11 487 val_487
+2008-04-09 11 480 val_480
+2008-04-09 11 401 val_401
+2008-04-09 11 288 val_288
+2008-04-09 11 191 val_191
+2008-04-09 11 5 val_5
+2008-04-09 11 244 val_244
+2008-04-09 11 438 val_438
+2008-04-09 11 128 val_128
+2008-04-09 11 467 val_467
+2008-04-09 11 432 val_432
+2008-04-09 11 202 val_202
+2008-04-09 11 316 val_316
+2008-04-09 11 229 val_229
+2008-04-09 11 469 val_469
+2008-04-09 11 463 val_463
+2008-04-09 11 280 val_280
+2008-04-09 11 2 val_2
+2008-04-09 11 35 val_35
+2008-04-09 11 283 val_283
+2008-04-09 11 331 val_331
+2008-04-09 11 235 val_235
+2008-04-09 11 80 val_80
+2008-04-09 11 44 val_44
+2008-04-09 11 193 val_193
+2008-04-09 11 321 val_321
+2008-04-09 11 335 val_335
+2008-04-09 11 104 val_104
+2008-04-09 11 466 val_466
+2008-04-09 11 366 val_366
+2008-04-09 11 175 val_175
+2008-04-09 11 403 val_403
+2008-04-09 11 483 val_483
+2008-04-09 11 53 val_53
+2008-04-09 11 105 val_105
+2008-04-09 11 257 val_257
+2008-04-09 11 406 val_406
+2008-04-09 11 409 val_409
+2008-04-09 11 190 val_190
+2008-04-09 11 406 val_406
+2008-04-09 11 401 val_401
+2008-04-09 11 114 val_114
+2008-04-09 11 258 val_258
+2008-04-09 11 90 val_90
+2008-04-09 11 203 val_203
+2008-04-09 11 262 val_262
+2008-04-09 11 348 val_348
+2008-04-09 11 424 val_424
+2008-04-09 11 12 val_12
+2008-04-09 11 396 val_396
+2008-04-09 11 201 val_201
+2008-04-09 11 217 val_217
+2008-04-09 11 164 val_164
+2008-04-09 11 431 val_431
+2008-04-09 11 454 val_454
+2008-04-09 11 478 val_478
+2008-04-09 11 298 val_298
+2008-04-09 11 125 val_125
+2008-04-09 11 431 val_431
+2008-04-09 11 164 val_164
+2008-04-09 11 424 val_424
+2008-04-09 11 187 val_187
+2008-04-09 11 382 val_382
+2008-04-09 11 5 val_5
+2008-04-09 11 70 val_70
+2008-04-09 11 397 val_397
+2008-04-09 11 480 val_480
+2008-04-09 11 291 val_291
+2008-04-09 11 24 val_24
+2008-04-09 11 351 val_351
+2008-04-09 11 255 val_255
+2008-04-09 11 104 val_104
+2008-04-09 11 70 val_70
+2008-04-09 11 163 val_163
+2008-04-09 11 438 val_438
+2008-04-09 11 119 val_119
+2008-04-09 11 414 val_414
+2008-04-09 11 200 val_200
+2008-04-09 11 491 val_491
+2008-04-09 11 237 val_237
+2008-04-09 11 439 val_439
+2008-04-09 11 360 val_360
+2008-04-09 11 248 val_248
+2008-04-09 11 479 val_479
+2008-04-09 11 305 val_305
+2008-04-09 11 417 val_417
+2008-04-09 11 199 val_199
+2008-04-09 11 444 val_444
+2008-04-09 11 120 val_120
+2008-04-09 11 429 val_429
+2008-04-09 11 169 val_169
+2008-04-09 11 443 val_443
+2008-04-09 11 323 val_323
+2008-04-09 11 325 val_325
+2008-04-09 11 277 val_277
+2008-04-09 11 230 val_230
+2008-04-09 11 478 val_478
+2008-04-09 11 178 val_178
+2008-04-09 11 468 val_468
+2008-04-09 11 310 val_310
+2008-04-09 11 317 val_317
+2008-04-09 11 333 val_333
+2008-04-09 11 493 val_493
+2008-04-09 11 460 val_460
+2008-04-09 11 207 val_207
+2008-04-09 11 249 val_249
+2008-04-09 11 265 val_265
+2008-04-09 11 480 val_480
+2008-04-09 11 83 val_83
+2008-04-09 11 136 val_136
+2008-04-09 11 353 val_353
+2008-04-09 11 172 val_172
+2008-04-09 11 214 val_214
+2008-04-09 11 462 val_462
+2008-04-09 11 233 val_233
+2008-04-09 11 406 val_406
+2008-04-09 11 133 val_133
+2008-04-09 11 175 val_175
+2008-04-09 11 189 val_189
+2008-04-09 11 454 val_454
+2008-04-09 11 375 val_375
+2008-04-09 11 401 val_401
+2008-04-09 11 421 val_421
+2008-04-09 11 407 val_407
+2008-04-09 11 384 val_384
+2008-04-09 11 256 val_256
+2008-04-09 11 26 val_26
+2008-04-09 11 134 val_134
+2008-04-09 11 67 val_67
+2008-04-09 11 384 val_384
+2008-04-09 11 379 val_379
+2008-04-09 11 18 val_18
+2008-04-09 11 462 val_462
+2008-04-09 11 492 val_492
+2008-04-09 11 100 val_100
+2008-04-09 11 298 val_298
+2008-04-09 11 9 val_9
+2008-04-09 11 341 val_341
+2008-04-09 11 498 val_498
+2008-04-09 11 146 val_146
+2008-04-09 11 458 val_458
+2008-04-09 11 362 val_362
+2008-04-09 11 186 val_186
+2008-04-09 11 285 val_285
+2008-04-09 11 348 val_348
+2008-04-09 11 167 val_167
+2008-04-09 11 18 val_18
+2008-04-09 11 273 val_273
+2008-04-09 11 183 val_183
+2008-04-09 11 281 val_281
+2008-04-09 11 344 val_344
+2008-04-09 11 97 val_97
+2008-04-09 11 469 val_469
+2008-04-09 11 315 val_315
+2008-04-09 11 84 val_84
+2008-04-09 11 28 val_28
+2008-04-09 11 37 val_37
+2008-04-09 11 448 val_448
+2008-04-09 11 152 val_152
+2008-04-09 11 348 val_348
+2008-04-09 11 307 val_307
+2008-04-09 11 194 val_194
+2008-04-09 11 414 val_414
+2008-04-09 11 477 val_477
+2008-04-09 11 222 val_222
+2008-04-09 11 126 val_126
+2008-04-09 11 90 val_90
+2008-04-09 11 169 val_169
+2008-04-09 11 403 val_403
+2008-04-09 11 400 val_400
+2008-04-09 11 200 val_200
+2008-04-09 11 97 val_97
+2008-04-09 12 238 val_238
+2008-04-09 12 86 val_86
+2008-04-09 12 311 val_311
+2008-04-09 12 27 val_27
+2008-04-09 12 165 val_165
+2008-04-09 12 409 val_409
+2008-04-09 12 255 val_255
+2008-04-09 12 278 val_278
+2008-04-09 12 98 val_98
+2008-04-09 12 484 val_484
+2008-04-09 12 265 val_265
+2008-04-09 12 193 val_193
+2008-04-09 12 401 val_401
+2008-04-09 12 150 val_150
+2008-04-09 12 273 val_273
+2008-04-09 12 224 val_224
+2008-04-09 12 369 val_369
+2008-04-09 12 66 val_66
+2008-04-09 12 128 val_128
+2008-04-09 12 213 val_213
+2008-04-09 12 146 val_146
+2008-04-09 12 406 val_406
+2008-04-09 12 429 val_429
+2008-04-09 12 374 val_374
+2008-04-09 12 152 val_152
+2008-04-09 12 469 val_469
+2008-04-09 12 145 val_145
+2008-04-09 12 495 val_495
+2008-04-09 12 37 val_37
+2008-04-09 12 327 val_327
+2008-04-09 12 281 val_281
+2008-04-09 12 277 val_277
+2008-04-09 12 209 val_209
+2008-04-09 12 15 val_15
+2008-04-09 12 82 val_82
+2008-04-09 12 403 val_403
+2008-04-09 12 166 val_166
+2008-04-09 12 417 val_417
+2008-04-09 12 430 val_430
+2008-04-09 12 252 val_252
+2008-04-09 12 292 val_292
+2008-04-09 12 219 val_219
+2008-04-09 12 287 val_287
+2008-04-09 12 153 val_153
+2008-04-09 12 193 val_193
+2008-04-09 12 338 val_338
+2008-04-09 12 446 val_446
+2008-04-09 12 459 val_459
+2008-04-09 12 394 val_394
+2008-04-09 12 237 val_237
+2008-04-09 12 482 val_482
+2008-04-09 12 174 val_174
+2008-04-09 12 413 val_413
+2008-04-09 12 494 val_494
+2008-04-09 12 207 val_207
+2008-04-09 12 199 val_199
+2008-04-09 12 466 val_466
+2008-04-09 12 208 val_208
+2008-04-09 12 174 val_174
+2008-04-09 12 399 val_399
+2008-04-09 12 396 val_396
+2008-04-09 12 247 val_247
+2008-04-09 12 417 val_417
+2008-04-09 12 489 val_489
+2008-04-09 12 162 val_162
+2008-04-09 12 377 val_377
+2008-04-09 12 397 val_397
+2008-04-09 12 309 val_309
+2008-04-09 12 365 val_365
+2008-04-09 12 266 val_266
+2008-04-09 12 439 val_439
+2008-04-09 12 342 val_342
+2008-04-09 12 367 val_367
+2008-04-09 12 325 val_325
+2008-04-09 12 167 val_167
+2008-04-09 12 195 val_195
+2008-04-09 12 475 val_475
+2008-04-09 12 17 val_17
+2008-04-09 12 113 val_113
+2008-04-09 12 155 val_155
+2008-04-09 12 203 val_203
+2008-04-09 12 339 val_339
+2008-04-09 12 0 val_0
+2008-04-09 12 455 val_455
+2008-04-09 12 128 val_128
+2008-04-09 12 311 val_311
+2008-04-09 12 316 val_316
+2008-04-09 12 57 val_57
+2008-04-09 12 302 val_302
+2008-04-09 12 205 val_205
+2008-04-09 12 149 val_149
+2008-04-09 12 438 val_438
+2008-04-09 12 345 val_345
+2008-04-09 12 129 val_129
+2008-04-09 12 170 val_170
+2008-04-09 12 20 val_20
+2008-04-09 12 489 val_489
+2008-04-09 12 157 val_157
+2008-04-09 12 378 val_378
+2008-04-09 12 221 val_221
+2008-04-09 12 92 val_92
+2008-04-09 12 111 val_111
+2008-04-09 12 47 val_47
+2008-04-09 12 72 val_72
+2008-04-09 12 4 val_4
+2008-04-09 12 280 val_280
+2008-04-09 12 35 val_35
+2008-04-09 12 427 val_427
+2008-04-09 12 277 val_277
+2008-04-09 12 208 val_208
+2008-04-09 12 356 val_356
+2008-04-09 12 399 val_399
+2008-04-09 12 169 val_169
+2008-04-09 12 382 val_382
+2008-04-09 12 498 val_498
+2008-04-09 12 125 val_125
+2008-04-09 12 386 val_386
+2008-04-09 12 437 val_437
+2008-04-09 12 469 val_469
+2008-04-09 12 192 val_192
+2008-04-09 12 286 val_286
+2008-04-09 12 187 val_187
+2008-04-09 12 176 val_176
+2008-04-09 12 54 val_54
+2008-04-09 12 459 val_459
+2008-04-09 12 51 val_51
+2008-04-09 12 138 val_138
+2008-04-09 12 103 val_103
+2008-04-09 12 239 val_239
+2008-04-09 12 213 val_213
+2008-04-09 12 216 val_216
+2008-04-09 12 430 val_430
+2008-04-09 12 278 val_278
+2008-04-09 12 176 val_176
+2008-04-09 12 289 val_289
+2008-04-09 12 221 val_221
+2008-04-09 12 65 val_65
+2008-04-09 12 318 val_318
+2008-04-09 12 332 val_332
+2008-04-09 12 311 val_311
+2008-04-09 12 275 val_275
+2008-04-09 12 137 val_137
+2008-04-09 12 241 val_241
+2008-04-09 12 83 val_83
+2008-04-09 12 333 val_333
+2008-04-09 12 180 val_180
+2008-04-09 12 284 val_284
+2008-04-09 12 12 val_12
+2008-04-09 12 230 val_230
+2008-04-09 12 181 val_181
+2008-04-09 12 67 val_67
+2008-04-09 12 260 val_260
+2008-04-09 12 404 val_404
+2008-04-09 12 384 val_384
+2008-04-09 12 489 val_489
+2008-04-09 12 353 val_353
+2008-04-09 12 373 val_373
+2008-04-09 12 272 val_272
+2008-04-09 12 138 val_138
+2008-04-09 12 217 val_217
+2008-04-09 12 84 val_84
+2008-04-09 12 348 val_348
+2008-04-09 12 466 val_466
+2008-04-09 12 58 val_58
+2008-04-09 12 8 val_8
+2008-04-09 12 411 val_411
+2008-04-09 12 230 val_230
+2008-04-09 12 208 val_208
+2008-04-09 12 348 val_348
+2008-04-09 12 24 val_24
+2008-04-09 12 463 val_463
+2008-04-09 12 431 val_431
+2008-04-09 12 179 val_179
+2008-04-09 12 172 val_172
+2008-04-09 12 42 val_42
+2008-04-09 12 129 val_129
+2008-04-09 12 158 val_158
+2008-04-09 12 119 val_119
+2008-04-09 12 496 val_496
+2008-04-09 12 0 val_0
+2008-04-09 12 322 val_322
+2008-04-09 12 197 val_197
+2008-04-09 12 468 val_468
+2008-04-09 12 393 val_393
+2008-04-09 12 454 val_454
+2008-04-09 12 100 val_100
+2008-04-09 12 298 val_298
+2008-04-09 12 199 val_199
+2008-04-09 12 191 val_191
+2008-04-09 12 418 val_418
+2008-04-09 12 96 val_96
+2008-04-09 12 26 val_26
+2008-04-09 12 165 val_165
+2008-04-09 12 327 val_327
+2008-04-09 12 230 val_230
+2008-04-09 12 205 val_205
+2008-04-09 12 120 val_120
+2008-04-09 12 131 val_131
+2008-04-09 12 51 val_51
+2008-04-09 12 404 val_404
+2008-04-09 12 43 val_43
+2008-04-09 12 436 val_436
+2008-04-09 12 156 val_156
+2008-04-09 12 469 val_469
+2008-04-09 12 468 val_468
+2008-04-09 12 308 val_308
+2008-04-09 12 95 val_95
+2008-04-09 12 196 val_196
+2008-04-09 12 288 val_288
+2008-04-09 12 481 val_481
+2008-04-09 12 457 val_457
+2008-04-09 12 98 val_98
+2008-04-09 12 282 val_282
+2008-04-09 12 197 val_197
+2008-04-09 12 187 val_187
+2008-04-09 12 318 val_318
+2008-04-09 12 318 val_318
+2008-04-09 12 409 val_409
+2008-04-09 12 470 val_470
+2008-04-09 12 137 val_137
+2008-04-09 12 369 val_369
+2008-04-09 12 316 val_316
+2008-04-09 12 169 val_169
+2008-04-09 12 413 val_413
+2008-04-09 12 85 val_85
+2008-04-09 12 77 val_77
+2008-04-09 12 0 val_0
+2008-04-09 12 490 val_490
+2008-04-09 12 87 val_87
+2008-04-09 12 364 val_364
+2008-04-09 12 179 val_179
+2008-04-09 12 118 val_118
+2008-04-09 12 134 val_134
+2008-04-09 12 395 val_395
+2008-04-09 12 282 val_282
+2008-04-09 12 138 val_138
+2008-04-09 12 238 val_238
+2008-04-09 12 419 val_419
+2008-04-09 12 15 val_15
+2008-04-09 12 118 val_118
+2008-04-09 12 72 val_72
+2008-04-09 12 90 val_90
+2008-04-09 12 307 val_307
+2008-04-09 12 19 val_19
+2008-04-09 12 435 val_435
+2008-04-09 12 10 val_10
+2008-04-09 12 277 val_277
+2008-04-09 12 273 val_273
+2008-04-09 12 306 val_306
+2008-04-09 12 224 val_224
+2008-04-09 12 309 val_309
+2008-04-09 12 389 val_389
+2008-04-09 12 327 val_327
+2008-04-09 12 242 val_242
+2008-04-09 12 369 val_369
+2008-04-09 12 392 val_392
+2008-04-09 12 272 val_272
+2008-04-09 12 331 val_331
+2008-04-09 12 401 val_401
+2008-04-09 12 242 val_242
+2008-04-09 12 452 val_452
+2008-04-09 12 177 val_177
+2008-04-09 12 226 val_226
+2008-04-09 12 5 val_5
+2008-04-09 12 497 val_497
+2008-04-09 12 402 val_402
+2008-04-09 12 396 val_396
+2008-04-09 12 317 val_317
+2008-04-09 12 395 val_395
+2008-04-09 12 58 val_58
+2008-04-09 12 35 val_35
+2008-04-09 12 336 val_336
+2008-04-09 12 95 val_95
+2008-04-09 12 11 val_11
+2008-04-09 12 168 val_168
+2008-04-09 12 34 val_34
+2008-04-09 12 229 val_229
+2008-04-09 12 233 val_233
+2008-04-09 12 143 val_143
+2008-04-09 12 472 val_472
+2008-04-09 12 322 val_322
+2008-04-09 12 498 val_498
+2008-04-09 12 160 val_160
+2008-04-09 12 195 val_195
+2008-04-09 12 42 val_42
+2008-04-09 12 321 val_321
+2008-04-09 12 430 val_430
+2008-04-09 12 119 val_119
+2008-04-09 12 489 val_489
+2008-04-09 12 458 val_458
+2008-04-09 12 78 val_78
+2008-04-09 12 76 val_76
+2008-04-09 12 41 val_41
+2008-04-09 12 223 val_223
+2008-04-09 12 492 val_492
+2008-04-09 12 149 val_149
+2008-04-09 12 449 val_449
+2008-04-09 12 218 val_218
+2008-04-09 12 228 val_228
+2008-04-09 12 138 val_138
+2008-04-09 12 453 val_453
+2008-04-09 12 30 val_30
+2008-04-09 12 209 val_209
+2008-04-09 12 64 val_64
+2008-04-09 12 468 val_468
+2008-04-09 12 76 val_76
+2008-04-09 12 74 val_74
+2008-04-09 12 342 val_342
+2008-04-09 12 69 val_69
+2008-04-09 12 230 val_230
+2008-04-09 12 33 val_33
+2008-04-09 12 368 val_368
+2008-04-09 12 103 val_103
+2008-04-09 12 296 val_296
+2008-04-09 12 113 val_113
+2008-04-09 12 216 val_216
+2008-04-09 12 367 val_367
+2008-04-09 12 344 val_344
+2008-04-09 12 167 val_167
+2008-04-09 12 274 val_274
+2008-04-09 12 219 val_219
+2008-04-09 12 239 val_239
+2008-04-09 12 485 val_485
+2008-04-09 12 116 val_116
+2008-04-09 12 223 val_223
+2008-04-09 12 256 val_256
+2008-04-09 12 263 val_263
+2008-04-09 12 70 val_70
+2008-04-09 12 487 val_487
+2008-04-09 12 480 val_480
+2008-04-09 12 401 val_401
+2008-04-09 12 288 val_288
+2008-04-09 12 191 val_191
+2008-04-09 12 5 val_5
+2008-04-09 12 244 val_244
+2008-04-09 12 438 val_438
+2008-04-09 12 128 val_128
+2008-04-09 12 467 val_467
+2008-04-09 12 432 val_432
+2008-04-09 12 202 val_202
+2008-04-09 12 316 val_316
+2008-04-09 12 229 val_229
+2008-04-09 12 469 val_469
+2008-04-09 12 463 val_463
+2008-04-09 12 280 val_280
+2008-04-09 12 2 val_2
+2008-04-09 12 35 val_35
+2008-04-09 12 283 val_283
+2008-04-09 12 331 val_331
+2008-04-09 12 235 val_235
+2008-04-09 12 80 val_80
+2008-04-09 12 44 val_44
+2008-04-09 12 193 val_193
+2008-04-09 12 321 val_321
+2008-04-09 12 335 val_335
+2008-04-09 12 104 val_104
+2008-04-09 12 466 val_466
+2008-04-09 12 366 val_366
+2008-04-09 12 175 val_175
+2008-04-09 12 403 val_403
+2008-04-09 12 483 val_483
+2008-04-09 12 53 val_53
+2008-04-09 12 105 val_105
+2008-04-09 12 257 val_257
+2008-04-09 12 406 val_406
+2008-04-09 12 409 val_409
+2008-04-09 12 190 val_190
+2008-04-09 12 406 val_406
+2008-04-09 12 401 val_401
+2008-04-09 12 114 val_114
+2008-04-09 12 258 val_258
+2008-04-09 12 90 val_90
+2008-04-09 12 203 val_203
+2008-04-09 12 262 val_262
+2008-04-09 12 348 val_348
+2008-04-09 12 424 val_424
+2008-04-09 12 12 val_12
+2008-04-09 12 396 val_396
+2008-04-09 12 201 val_201
+2008-04-09 12 217 val_217
+2008-04-09 12 164 val_164
+2008-04-09 12 431 val_431
+2008-04-09 12 454 val_454
+2008-04-09 12 478 val_478
+2008-04-09 12 298 val_298
+2008-04-09 12 125 val_125
+2008-04-09 12 431 val_431
+2008-04-09 12 164 val_164
+2008-04-09 12 424 val_424
+2008-04-09 12 187 val_187
+2008-04-09 12 382 val_382
+2008-04-09 12 5 val_5
+2008-04-09 12 70 val_70
+2008-04-09 12 397 val_397
+2008-04-09 12 480 val_480
+2008-04-09 12 291 val_291
+2008-04-09 12 24 val_24
+2008-04-09 12 351 val_351
+2008-04-09 12 255 val_255
+2008-04-09 12 104 val_104
+2008-04-09 12 70 val_70
+2008-04-09 12 163 val_163
+2008-04-09 12 438 val_438
+2008-04-09 12 119 val_119
+2008-04-09 12 414 val_414
+2008-04-09 12 200 val_200
+2008-04-09 12 491 val_491
+2008-04-09 12 237 val_237
+2008-04-09 12 439 val_439
+2008-04-09 12 360 val_360
+2008-04-09 12 248 val_248
+2008-04-09 12 479 val_479
+2008-04-09 12 305 val_305
+2008-04-09 12 417 val_417
+2008-04-09 12 199 val_199
+2008-04-09 12 444 val_444
+2008-04-09 12 120 val_120
+2008-04-09 12 429 val_429
+2008-04-09 12 169 val_169
+2008-04-09 12 443 val_443
+2008-04-09 12 323 val_323
+2008-04-09 12 325 val_325
+2008-04-09 12 277 val_277
+2008-04-09 12 230 val_230
+2008-04-09 12 478 val_478
+2008-04-09 12 178 val_178
+2008-04-09 12 468 val_468
+2008-04-09 12 310 val_310
+2008-04-09 12 317 val_317
+2008-04-09 12 333 val_333
+2008-04-09 12 493 val_493
+2008-04-09 12 460 val_460
+2008-04-09 12 207 val_207
+2008-04-09 12 249 val_249
+2008-04-09 12 265 val_265
+2008-04-09 12 480 val_480
+2008-04-09 12 83 val_83
+2008-04-09 12 136 val_136
+2008-04-09 12 353 val_353
+2008-04-09 12 172 val_172
+2008-04-09 12 214 val_214
+2008-04-09 12 462 val_462
+2008-04-09 12 233 val_233
+2008-04-09 12 406 val_406
+2008-04-09 12 133 val_133
+2008-04-09 12 175 val_175
+2008-04-09 12 189 val_189
+2008-04-09 12 454 val_454
+2008-04-09 12 375 val_375
+2008-04-09 12 401 val_401
+2008-04-09 12 421 val_421
+2008-04-09 12 407 val_407
+2008-04-09 12 384 val_384
+2008-04-09 12 256 val_256
+2008-04-09 12 26 val_26
+2008-04-09 12 134 val_134
+2008-04-09 12 67 val_67
+2008-04-09 12 384 val_384
+2008-04-09 12 379 val_379
+2008-04-09 12 18 val_18
+2008-04-09 12 462 val_462
+2008-04-09 12 492 val_492
+2008-04-09 12 100 val_100
+2008-04-09 12 298 val_298
+2008-04-09 12 9 val_9
+2008-04-09 12 341 val_341
+2008-04-09 12 498 val_498
+2008-04-09 12 146 val_146
+2008-04-09 12 458 val_458
+2008-04-09 12 362 val_362
+2008-04-09 12 186 val_186
+2008-04-09 12 285 val_285
+2008-04-09 12 348 val_348
+2008-04-09 12 167 val_167
+2008-04-09 12 18 val_18
+2008-04-09 12 273 val_273
+2008-04-09 12 183 val_183
+2008-04-09 12 281 val_281
+2008-04-09 12 344 val_344
+2008-04-09 12 97 val_97
+2008-04-09 12 469 val_469
+2008-04-09 12 315 val_315
+2008-04-09 12 84 val_84
+2008-04-09 12 28 val_28
+2008-04-09 12 37 val_37
+2008-04-09 12 448 val_448
+2008-04-09 12 152 val_152
+2008-04-09 12 348 val_348
+2008-04-09 12 307 val_307
+2008-04-09 12 194 val_194
+2008-04-09 12 414 val_414
+2008-04-09 12 477 val_477
+2008-04-09 12 222 val_222
+2008-04-09 12 126 val_126
+2008-04-09 12 90 val_90
+2008-04-09 12 169 val_169
+2008-04-09 12 403 val_403
+2008-04-09 12 400 val_400
+2008-04-09 12 200 val_200
+2008-04-09 12 97 val_97
diff --git a/sql/hive/src/test/resources/golden/read from cached table-0-ce3797dc14a603cba2a5e58c8612de5b b/sql/hive/src/test/resources/golden/read from cached table-0-ce3797dc14a603cba2a5e58c8612de5b
new file mode 100644
index 0000000000000..60878ffb77064
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/read from cached table-0-ce3797dc14a603cba2a5e58c8612de5b
@@ -0,0 +1 @@
+238 val_238
diff --git a/sql/hive/src/test/resources/golden/read from uncached table-0-ce3797dc14a603cba2a5e58c8612de5b b/sql/hive/src/test/resources/golden/read from uncached table-0-ce3797dc14a603cba2a5e58c8612de5b
new file mode 100644
index 0000000000000..60878ffb77064
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/read from uncached table-0-ce3797dc14a603cba2a5e58c8612de5b
@@ -0,0 +1 @@
+238 val_238
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-0-86a409d8b868dc5f1a3bd1e04c2bc28c b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-0-86a409d8b868dc5f1a3bd1e04c2bc28c
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-0-86a409d8b868dc5f1a3bd1e04c2bc28c
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-1-2b1df88619e34f221d39598b5cd73283 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-1-2b1df88619e34f221d39598b5cd73283
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-1-2b1df88619e34f221d39598b5cd73283
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-10-60eadbb52f8857830a3034952c631ace b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-10-60eadbb52f8857830a3034952c631ace
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-11-dbe79f90862dc5c6cc4a4fa4b4b6c655 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-11-dbe79f90862dc5c6cc4a4fa4b4b6c655
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-12-60018cae9a0476dc6a0ab4264310edb5 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-12-60018cae9a0476dc6a0ab4264310edb5
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-2-7562d4fee13f3ba935a2e824f86a4224 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-2-7562d4fee13f3ba935a2e824f86a4224
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-2-7562d4fee13f3ba935a2e824f86a4224
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-3-bdb30a5d6887ee4fb089f8676313eafd b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-3-bdb30a5d6887ee4fb089f8676313eafd
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-3-bdb30a5d6887ee4fb089f8676313eafd
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-4-10713b30ecb3c88acdd775bf9628c38c b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-4-10713b30ecb3c88acdd775bf9628c38c
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-4-10713b30ecb3c88acdd775bf9628c38c
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-5-bab89dfffa77258e34a595e0e79986e3 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-5-bab89dfffa77258e34a595e0e79986e3
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-5-bab89dfffa77258e34a595e0e79986e3
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-6-6f53d5613262d393d82d159ec5dc16dc b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-6-6f53d5613262d393d82d159ec5dc16dc
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-6-6f53d5613262d393d82d159ec5dc16dc
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-7-ad4ddb5c5d6b994f4dba35f6162b6a9f b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-7-ad4ddb5c5d6b994f4dba35f6162b6a9f
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-8-f9dd797f1c90e2108cfee585f443c132 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-8-f9dd797f1c90e2108cfee585f443c132
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-9-22fdd8380f2652de2492b34a425d46d7 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-9-22fdd8380f2652de2492b34a425d46d7
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-0-7a9e67189d3d4151f23b12c22bde06b5 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-0-7a9e67189d3d4151f23b12c22bde06b5
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-1-86a409d8b868dc5f1a3bd1e04c2bc28c b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-1-86a409d8b868dc5f1a3bd1e04c2bc28c
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-1-86a409d8b868dc5f1a3bd1e04c2bc28c
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-10-22fdd8380f2652de2492b34a425d46d7 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-10-22fdd8380f2652de2492b34a425d46d7
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-11-60eadbb52f8857830a3034952c631ace b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-11-60eadbb52f8857830a3034952c631ace
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-12-dbe79f90862dc5c6cc4a4fa4b4b6c655 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-12-dbe79f90862dc5c6cc4a4fa4b4b6c655
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-13-60018cae9a0476dc6a0ab4264310edb5 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-13-60018cae9a0476dc6a0ab4264310edb5
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-2-2b1df88619e34f221d39598b5cd73283 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-2-2b1df88619e34f221d39598b5cd73283
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-2-2b1df88619e34f221d39598b5cd73283
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-3-7562d4fee13f3ba935a2e824f86a4224 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-3-7562d4fee13f3ba935a2e824f86a4224
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-3-7562d4fee13f3ba935a2e824f86a4224
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-4-bdb30a5d6887ee4fb089f8676313eafd b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-4-bdb30a5d6887ee4fb089f8676313eafd
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-4-bdb30a5d6887ee4fb089f8676313eafd
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-5-10713b30ecb3c88acdd775bf9628c38c b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-5-10713b30ecb3c88acdd775bf9628c38c
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-5-10713b30ecb3c88acdd775bf9628c38c
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-6-bab89dfffa77258e34a595e0e79986e3 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-6-bab89dfffa77258e34a595e0e79986e3
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-6-bab89dfffa77258e34a595e0e79986e3
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-7-6f53d5613262d393d82d159ec5dc16dc b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-7-6f53d5613262d393d82d159ec5dc16dc
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-7-6f53d5613262d393d82d159ec5dc16dc
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-8-7a45282169e5a15d70ae0afb9e67ec9a b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-8-7a45282169e5a15d70ae0afb9e67ec9a
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-9-f9dd797f1c90e2108cfee585f443c132 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-9-f9dd797f1c90e2108cfee585f443c132
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/timestamp_comparison-0-48751533b44ea9e8ac3131767c2fed05 b/sql/hive/src/test/resources/golden/timestamp_comparison-0-48751533b44ea9e8ac3131767c2fed05
new file mode 100644
index 0000000000000..c508d5366f70b
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/timestamp_comparison-0-48751533b44ea9e8ac3131767c2fed05
@@ -0,0 +1 @@
+false
diff --git a/sql/hive/src/test/resources/golden/timestamp_comparison-1-60557e7bd2822c89fa8b076a9d0520fc b/sql/hive/src/test/resources/golden/timestamp_comparison-1-60557e7bd2822c89fa8b076a9d0520fc
new file mode 100644
index 0000000000000..c508d5366f70b
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/timestamp_comparison-1-60557e7bd2822c89fa8b076a9d0520fc
@@ -0,0 +1 @@
+false
diff --git a/sql/hive/src/test/resources/golden/timestamp_comparison-2-f96a9d88327951bd93f672dc2463ecd4 b/sql/hive/src/test/resources/golden/timestamp_comparison-2-f96a9d88327951bd93f672dc2463ecd4
new file mode 100644
index 0000000000000..27ba77ddaf615
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/timestamp_comparison-2-f96a9d88327951bd93f672dc2463ecd4
@@ -0,0 +1 @@
+true
diff --git a/sql/hive/src/test/resources/golden/timestamp_comparison-3-13e17ed811165196416f777cbc162592 b/sql/hive/src/test/resources/golden/timestamp_comparison-3-13e17ed811165196416f777cbc162592
new file mode 100644
index 0000000000000..c508d5366f70b
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/timestamp_comparison-3-13e17ed811165196416f777cbc162592
@@ -0,0 +1 @@
+false
diff --git a/sql/hive/src/test/resources/golden/timestamp_comparison-4-4fa8a36edbefde4427c2ab2cf30e6399 b/sql/hive/src/test/resources/golden/timestamp_comparison-4-4fa8a36edbefde4427c2ab2cf30e6399
new file mode 100644
index 0000000000000..27ba77ddaf615
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/timestamp_comparison-4-4fa8a36edbefde4427c2ab2cf30e6399
@@ -0,0 +1 @@
+true
diff --git a/sql/hive/src/test/resources/golden/timestamp_comparison-5-7e4fb6e8ba01df422e4c67e06a0c8453 b/sql/hive/src/test/resources/golden/timestamp_comparison-5-7e4fb6e8ba01df422e4c67e06a0c8453
new file mode 100644
index 0000000000000..27ba77ddaf615
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/timestamp_comparison-5-7e4fb6e8ba01df422e4c67e06a0c8453
@@ -0,0 +1 @@
+true
diff --git a/sql/hive/src/test/resources/golden/timestamp_comparison-6-8c8e73673a950f6b3d960b08fcea076f b/sql/hive/src/test/resources/golden/timestamp_comparison-6-8c8e73673a950f6b3d960b08fcea076f
new file mode 100644
index 0000000000000..c508d5366f70b
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/timestamp_comparison-6-8c8e73673a950f6b3d960b08fcea076f
@@ -0,0 +1 @@
+false
diff --git a/sql/hive/src/test/resources/golden/timestamp_comparison-7-510c0a2a57dc5df8588bd13c4152f8bc b/sql/hive/src/test/resources/golden/timestamp_comparison-7-510c0a2a57dc5df8588bd13c4152f8bc
new file mode 100644
index 0000000000000..27ba77ddaf615
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/timestamp_comparison-7-510c0a2a57dc5df8588bd13c4152f8bc
@@ -0,0 +1 @@
+true
diff --git a/sql/hive/src/test/resources/golden/timestamp_comparison-8-659d5b1ae8200f13f265270e52a3dd65 b/sql/hive/src/test/resources/golden/timestamp_comparison-8-659d5b1ae8200f13f265270e52a3dd65
new file mode 100644
index 0000000000000..27ba77ddaf615
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/timestamp_comparison-8-659d5b1ae8200f13f265270e52a3dd65
@@ -0,0 +1 @@
+true
diff --git a/sql/hive/src/test/resources/golden/type_cast_1-0-60ea21e6e7d054a65f959fc89acf1b3d b/sql/hive/src/test/resources/golden/type_cast_1-0-60ea21e6e7d054a65f959fc89acf1b3d
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/type_cast_1-1-53a667981ad567b2ab977f67d65c5825 b/sql/hive/src/test/resources/golden/type_cast_1-1-53a667981ad567b2ab977f67d65c5825
new file mode 100644
index 0000000000000..7ed6ff82de6bc
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/type_cast_1-1-53a667981ad567b2ab977f67d65c5825
@@ -0,0 +1 @@
+5
diff --git a/sql/hive/src/test/resources/golden/udf_printf-0-e86d559aeb84a4cc017a103182c22bfb b/sql/hive/src/test/resources/golden/udf_printf-0-e86d559aeb84a4cc017a103182c22bfb
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/udf_printf-1-19c61fce27310ab2590062d643f7b26e b/sql/hive/src/test/resources/golden/udf_printf-1-19c61fce27310ab2590062d643f7b26e
new file mode 100644
index 0000000000000..1635ff88dd768
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_printf-1-19c61fce27310ab2590062d643f7b26e
@@ -0,0 +1 @@
+printf(String format, Obj... args) - function that can format strings according to printf-style format strings
diff --git a/sql/hive/src/test/resources/golden/udf_printf-2-25aa6950cae2bb781c336378f63ceaee b/sql/hive/src/test/resources/golden/udf_printf-2-25aa6950cae2bb781c336378f63ceaee
new file mode 100644
index 0000000000000..62440ee68e145
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_printf-2-25aa6950cae2bb781c336378f63ceaee
@@ -0,0 +1,4 @@
+printf(String format, Obj... args) - function that can format strings according to printf-style format strings
+Example:
+ > SELECT printf("Hello World %d %s", 100, "days")FROM src LIMIT 1;
+ "Hello World 100 days"
diff --git a/sql/hive/src/test/resources/golden/udf_printf-3-9c568a0473888396bd46507e8b330c36 b/sql/hive/src/test/resources/golden/udf_printf-3-9c568a0473888396bd46507e8b330c36
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/udf_printf-4-91728e546b450bdcbb05ef30f13be475 b/sql/hive/src/test/resources/golden/udf_printf-4-91728e546b450bdcbb05ef30f13be475
new file mode 100644
index 0000000000000..39cb945991403
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_printf-4-91728e546b450bdcbb05ef30f13be475
@@ -0,0 +1 @@
+Hello World 100 days
diff --git a/sql/hive/src/test/resources/golden/udf_printf-5-3141a0421605b091ee5a9e99d7d605fb b/sql/hive/src/test/resources/golden/udf_printf-5-3141a0421605b091ee5a9e99d7d605fb
new file mode 100644
index 0000000000000..04bf5e552a576
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_printf-5-3141a0421605b091ee5a9e99d7d605fb
@@ -0,0 +1 @@
+All Type Test: false, A, 15000, 1.234000e+01, +27183.2401, 2300.41, 32, corret, 0x1.002p8
diff --git a/sql/hive/src/test/resources/golden/udf_printf-6-ec37b73012f3cbbbc0422744b0db8294 b/sql/hive/src/test/resources/golden/udf_printf-6-ec37b73012f3cbbbc0422744b0db8294
new file mode 100644
index 0000000000000..2e9f7509968a3
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_printf-6-ec37b73012f3cbbbc0422744b0db8294
@@ -0,0 +1 @@
+Color red, String Null: null, number1 123456, number2 00089, Integer Null: null, hex 0xff, float 3.14 Double Null: null
diff --git a/sql/hive/src/test/resources/golden/udf_printf-7-5769f3a5b3300ca1d8b861229e976126 b/sql/hive/src/test/resources/golden/udf_printf-7-5769f3a5b3300ca1d8b861229e976126
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-10-51822ac740629bebd81d2abda6e1144 b/sql/hive/src/test/resources/golden/udf_to_boolean-10-51822ac740629bebd81d2abda6e1144
new file mode 100644
index 0000000000000..c508d5366f70b
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_to_boolean-10-51822ac740629bebd81d2abda6e1144
@@ -0,0 +1 @@
+false
diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-11-441306cae24618c49ec63445a31bf16b b/sql/hive/src/test/resources/golden/udf_to_boolean-11-441306cae24618c49ec63445a31bf16b
new file mode 100644
index 0000000000000..c508d5366f70b
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_to_boolean-11-441306cae24618c49ec63445a31bf16b
@@ -0,0 +1 @@
+false
diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-12-bfcc534e73e320a1cfad9c584678d870 b/sql/hive/src/test/resources/golden/udf_to_boolean-12-bfcc534e73e320a1cfad9c584678d870
new file mode 100644
index 0000000000000..c508d5366f70b
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_to_boolean-12-bfcc534e73e320a1cfad9c584678d870
@@ -0,0 +1 @@
+false
diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-13-a2bddaa5db1841bb4617239b9f17a06d b/sql/hive/src/test/resources/golden/udf_to_boolean-13-a2bddaa5db1841bb4617239b9f17a06d
new file mode 100644
index 0000000000000..c508d5366f70b
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_to_boolean-13-a2bddaa5db1841bb4617239b9f17a06d
@@ -0,0 +1 @@
+false
diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-14-773801b833cf72d35016916b786275b5 b/sql/hive/src/test/resources/golden/udf_to_boolean-14-773801b833cf72d35016916b786275b5
new file mode 100644
index 0000000000000..c508d5366f70b
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_to_boolean-14-773801b833cf72d35016916b786275b5
@@ -0,0 +1 @@
+false
diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-15-4071ed0ff57b53963d5ee662fa9db0b0 b/sql/hive/src/test/resources/golden/udf_to_boolean-15-4071ed0ff57b53963d5ee662fa9db0b0
new file mode 100644
index 0000000000000..c508d5366f70b
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_to_boolean-15-4071ed0ff57b53963d5ee662fa9db0b0
@@ -0,0 +1 @@
+false
diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-16-6b441df08afdc0c6c4a82670997dabb5 b/sql/hive/src/test/resources/golden/udf_to_boolean-16-6b441df08afdc0c6c4a82670997dabb5
new file mode 100644
index 0000000000000..c508d5366f70b
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_to_boolean-16-6b441df08afdc0c6c4a82670997dabb5
@@ -0,0 +1 @@
+false
diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-17-85342c694d7f35e7eedb24e850d0c7df b/sql/hive/src/test/resources/golden/udf_to_boolean-17-85342c694d7f35e7eedb24e850d0c7df
new file mode 100644
index 0000000000000..c508d5366f70b
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_to_boolean-17-85342c694d7f35e7eedb24e850d0c7df
@@ -0,0 +1 @@
+false
diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-18-fcd7af0e71d3e2d934239ba606e3ed87 b/sql/hive/src/test/resources/golden/udf_to_boolean-18-fcd7af0e71d3e2d934239ba606e3ed87
new file mode 100644
index 0000000000000..7951defec192a
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_to_boolean-18-fcd7af0e71d3e2d934239ba606e3ed87
@@ -0,0 +1 @@
+NULL
diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-19-dcdb12fe551aa68a56921822f5d1a343 b/sql/hive/src/test/resources/golden/udf_to_boolean-19-dcdb12fe551aa68a56921822f5d1a343
new file mode 100644
index 0000000000000..7951defec192a
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_to_boolean-19-dcdb12fe551aa68a56921822f5d1a343
@@ -0,0 +1 @@
+NULL
diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-20-131900d39d9a20b431731a32fb9715f8 b/sql/hive/src/test/resources/golden/udf_to_boolean-20-131900d39d9a20b431731a32fb9715f8
new file mode 100644
index 0000000000000..7951defec192a
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_to_boolean-20-131900d39d9a20b431731a32fb9715f8
@@ -0,0 +1 @@
+NULL
diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-21-a5e28f4eb819e5a5e292e279f2990a7a b/sql/hive/src/test/resources/golden/udf_to_boolean-21-a5e28f4eb819e5a5e292e279f2990a7a
new file mode 100644
index 0000000000000..7951defec192a
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_to_boolean-21-a5e28f4eb819e5a5e292e279f2990a7a
@@ -0,0 +1 @@
+NULL
diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-22-93278c10d642fa242f303d89b3b1961d b/sql/hive/src/test/resources/golden/udf_to_boolean-22-93278c10d642fa242f303d89b3b1961d
new file mode 100644
index 0000000000000..7951defec192a
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_to_boolean-22-93278c10d642fa242f303d89b3b1961d
@@ -0,0 +1 @@
+NULL
diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-23-828558020ce907ffa7e847762a5e2358 b/sql/hive/src/test/resources/golden/udf_to_boolean-23-828558020ce907ffa7e847762a5e2358
new file mode 100644
index 0000000000000..7951defec192a
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_to_boolean-23-828558020ce907ffa7e847762a5e2358
@@ -0,0 +1 @@
+NULL
diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-24-e8ca597d87932af16c0cf29d662e92da b/sql/hive/src/test/resources/golden/udf_to_boolean-24-e8ca597d87932af16c0cf29d662e92da
new file mode 100644
index 0000000000000..7951defec192a
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_to_boolean-24-e8ca597d87932af16c0cf29d662e92da
@@ -0,0 +1 @@
+NULL
diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-25-86245727f90de9ce65a12c97a03a5635 b/sql/hive/src/test/resources/golden/udf_to_boolean-25-86245727f90de9ce65a12c97a03a5635
new file mode 100644
index 0000000000000..7951defec192a
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_to_boolean-25-86245727f90de9ce65a12c97a03a5635
@@ -0,0 +1 @@
+NULL
diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-26-552d7ec5a4e0c93dc59a61973e2d63a2 b/sql/hive/src/test/resources/golden/udf_to_boolean-26-552d7ec5a4e0c93dc59a61973e2d63a2
new file mode 100644
index 0000000000000..7951defec192a
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_to_boolean-26-552d7ec5a4e0c93dc59a61973e2d63a2
@@ -0,0 +1 @@
+NULL
diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-27-b61509b01b2fe3e7e4b72fedc74ff4f9 b/sql/hive/src/test/resources/golden/udf_to_boolean-27-b61509b01b2fe3e7e4b72fedc74ff4f9
new file mode 100644
index 0000000000000..7951defec192a
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_to_boolean-27-b61509b01b2fe3e7e4b72fedc74ff4f9
@@ -0,0 +1 @@
+NULL
diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-8-37229f303635a030f6cab20e0381f51f b/sql/hive/src/test/resources/golden/udf_to_boolean-8-37229f303635a030f6cab20e0381f51f
new file mode 100644
index 0000000000000..27ba77ddaf615
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_to_boolean-8-37229f303635a030f6cab20e0381f51f
@@ -0,0 +1 @@
+true
diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-9-be623247e4dbf119b43458b72d1be017 b/sql/hive/src/test/resources/golden/udf_to_boolean-9-be623247e4dbf119b43458b72d1be017
new file mode 100644
index 0000000000000..c508d5366f70b
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_to_boolean-9-be623247e4dbf119b43458b72d1be017
@@ -0,0 +1 @@
+false
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
new file mode 100644
index 0000000000000..79ec1f1cde019
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.spark.sql.execution.SparkLogicalPlan
+import org.apache.spark.sql.columnar.InMemoryColumnarTableScan
+import org.apache.spark.sql.hive.execution.HiveComparisonTest
+
+class CachedTableSuite extends HiveComparisonTest {
+ TestHive.loadTestTable("src")
+
+ test("cache table") {
+ TestHive.cacheTable("src")
+ }
+
+ createQueryTest("read from cached table",
+ "SELECT * FROM src LIMIT 1", reset = false)
+
+ test("check that table is cached and uncache") {
+ TestHive.table("src").queryExecution.analyzed match {
+ case SparkLogicalPlan(_ : InMemoryColumnarTableScan) => // Found evidence of caching
+ case noCache => fail(s"No cache node found in plan $noCache")
+ }
+ TestHive.uncacheTable("src")
+ }
+
+ createQueryTest("read from uncached table",
+ "SELECT * FROM src LIMIT 1", reset = false)
+
+ test("make sure table is uncached") {
+ TestHive.table("src").queryExecution.analyzed match {
+ case cachePlan @ SparkLogicalPlan(_ : InMemoryColumnarTableScan) =>
+ fail(s"Table still cached after uncache: $cachePlan")
+ case noCache => // Table uncached successfully
+ }
+ }
+
+ test("correct error on uncache of non-cached table") {
+ intercept[IllegalArgumentException] {
+ TestHive.uncacheTable("src")
+ }
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveSuite.scala
new file mode 100644
index 0000000000000..8137f99b227f4
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveSuite.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.api.java
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.api.java.JavaSparkContext
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.hive.TestHive
+
+// Implicits
+import scala.collection.JavaConversions._
+
+class JavaHiveSQLSuite extends FunSuite {
+ ignore("SELECT * FROM src") {
+ val javaCtx = new JavaSparkContext(TestSQLContext.sparkContext)
+ // There is a little trickery here to avoid instantiating two HiveContexts in the same JVM
+ val javaSqlCtx = new JavaHiveContext(javaCtx) {
+ override val sqlContext = TestHive
+ }
+
+ assert(
+ javaSqlCtx.hql("SELECT * FROM src").collect().map(_.getInt(0)) ===
+ TestHive.sql("SELECT * FROM src").collect().map(_.getInt(0)).toSeq)
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index c7a350ef94edd..3cc4562a88d66 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -125,7 +125,7 @@ abstract class HiveComparisonTest
}
protected def prepareAnswer(
- hiveQuery: TestHive.type#SqlQueryExecution,
+ hiveQuery: TestHive.type#HiveQLQueryExecution,
answer: Seq[String]): Seq[String] = {
val orderedAnswer = hiveQuery.logical match {
// Clean out non-deterministic time schema info.
@@ -170,7 +170,7 @@ abstract class HiveComparisonTest
}
val installHooksCommand = "(?i)SET.*hooks".r
- def createQueryTest(testCaseName: String, sql: String) {
+ def createQueryTest(testCaseName: String, sql: String, reset: Boolean = true) {
// If test sharding is enable, skip tests that are not in the correct shard.
shardInfo.foreach {
case (shardId, numShards) if testCaseName.hashCode % numShards != shardId => return
@@ -227,8 +227,8 @@ abstract class HiveComparisonTest
try {
// MINOR HACK: You must run a query before calling reset the first time.
- TestHive.sql("SHOW TABLES")
- TestHive.reset()
+ TestHive.hql("SHOW TABLES")
+ if (reset) { TestHive.reset() }
val hiveCacheFiles = queryList.zipWithIndex.map {
case (queryString, i) =>
@@ -256,7 +256,7 @@ abstract class HiveComparisonTest
hiveCachedResults
} else {
- val hiveQueries = queryList.map(new TestHive.SqlQueryExecution(_))
+ val hiveQueries = queryList.map(new TestHive.HiveQLQueryExecution(_))
// Make sure we can at least parse everything before attempting hive execution.
hiveQueries.foreach(_.logical)
val computedResults = (queryList.zipWithIndex, hiveQueries, hiveCacheFiles).zipped.map {
@@ -295,14 +295,14 @@ abstract class HiveComparisonTest
fail(errorMessage)
}
}.toSeq
- TestHive.reset()
+ if (reset) { TestHive.reset() }
computedResults
}
// Run w/ catalyst
val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) =>
- val query = new TestHive.SqlQueryExecution(queryString)
+ val query = new TestHive.HiveQLQueryExecution(queryString)
try { (query, prepareAnswer(query, query.stringResult())) } catch {
case e: Exception =>
val errorMessage =
@@ -359,7 +359,7 @@ abstract class HiveComparisonTest
// When we encounter an error we check to see if the environment is still okay by running a simple query.
// If this fails then we halt testing since something must have gone seriously wrong.
try {
- new TestHive.SqlQueryExecution("SELECT key FROM src").stringResult()
+ new TestHive.HiveQLQueryExecution("SELECT key FROM src").stringResult()
TestHive.runSqlHive("SELECT key FROM src")
} catch {
case e: Exception =>
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index f74b0fbb97c83..f76e16bc1afc5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -42,6 +42,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest {
"bucket_num_reducers",
"column_access_stats",
"concatenate_inherit_table_location",
+ "describe_pretty",
+ "describe_syntax",
+ "orc_ends_with_nulls",
// Setting a default property does not seem to get reset and thus changes the answer for many
// subsequent tests.
@@ -80,7 +83,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest {
"index_auto_update",
"index_auto_self_join",
"index_stale.*",
- "type_cast_1",
"index_compression",
"index_bitmap_compression",
"index_auto_multiple",
@@ -237,9 +239,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest {
"compute_stats_binary",
"compute_stats_boolean",
"compute_stats_double",
- "compute_stats_table",
+ "compute_stats_empty_table",
"compute_stats_long",
"compute_stats_string",
+ "compute_stats_table",
"convert_enum_to_string",
"correlationoptimizer11",
"correlationoptimizer15",
@@ -266,8 +269,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest {
"desc_non_existent_tbl",
"describe_comment_indent",
"describe_database_json",
- "describe_pretty",
- "describe_syntax",
+ "describe_formatted_view_partitioned",
+ "describe_formatted_view_partitioned_json",
"describe_table_json",
"diff_part_input_formats",
"disable_file_format_check",
@@ -339,8 +342,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest {
"input11_limit",
"input12",
"input12_hadoop20",
+ "input14",
"input19",
"input1_limit",
+ "input21",
"input22",
"input23",
"input24",
@@ -355,6 +360,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest {
"input7",
"input8",
"input9",
+ "inputddl4",
+ "inputddl7",
+ "inputddl8",
"input_limit",
"input_part0",
"input_part1",
@@ -368,9 +376,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest {
"input_part7",
"input_part8",
"input_part9",
- "inputddl4",
- "inputddl7",
- "inputddl8",
+ "input_testsequencefile",
+ "insert1",
+ "insert2_overwrite_partitions",
"insert_compressed",
"join0",
"join1",
@@ -385,6 +393,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest {
"join17",
"join18",
"join19",
+ "join_1to1",
"join2",
"join20",
"join21",
@@ -400,6 +409,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest {
"join30",
"join31",
"join32",
+ "join32_lessSize",
"join33",
"join34",
"join35",
@@ -415,13 +425,14 @@ class HiveCompatibilitySuite extends HiveQueryFileTest {
"join7",
"join8",
"join9",
- "join_1to1",
"join_array",
"join_casesensitive",
"join_empty",
"join_filters",
"join_hive_626",
+ "join_map_ppr",
"join_nulls",
+ "join_rc",
"join_reorder2",
"join_reorder3",
"join_reorder4",
@@ -435,22 +446,32 @@ class HiveCompatibilitySuite extends HiveQueryFileTest {
"literal_string",
"load_dyn_part7",
"load_file_with_space_in_the_name",
+ "loadpart1",
"louter_join_ppr",
"mapjoin_distinct",
"mapjoin_mapjoin",
"mapjoin_subquery",
"mapjoin_subquery2",
"mapjoin_test_outer",
+ "mapreduce1",
+ "mapreduce2",
"mapreduce3",
+ "mapreduce4",
+ "mapreduce5",
+ "mapreduce6",
"mapreduce7",
+ "mapreduce8",
"merge1",
"merge2",
"mergejoins",
"mergejoins_mixed",
+ "multigroupby_singlemr",
+ "multi_insert_gby",
+ "multi_insert_gby3",
+ "multi_insert_lateral_view",
+ "multi_join_union",
"multiMapJoin1",
"multiMapJoin2",
- "multi_join_union",
- "multigroupby_singlemr",
"noalias_subq1",
"nomore_ambiguous_table_col",
"nonblock_op_deduplicate",
@@ -466,16 +487,30 @@ class HiveCompatibilitySuite extends HiveQueryFileTest {
"nullinput2",
"nullscript",
"optional_outer",
+ "orc_dictionary_threshold",
+ "orc_empty_files",
"order",
"order2",
"outer_join_ppr",
+ "parallel",
+ "parenthesis_star_by",
+ "partcols1",
"part_inherit_tbl_props",
"part_inherit_tbl_props_empty",
"part_inherit_tbl_props_with_star",
"partition_schema1",
+ "partition_serde_format",
"partition_varchar1",
+ "partition_wise_fileformat4",
+ "partition_wise_fileformat5",
+ "partition_wise_fileformat6",
+ "partition_wise_fileformat7",
+ "partition_wise_fileformat9",
"plan_json",
"ppd1",
+ "ppd2",
+ "ppd_clusterby",
+ "ppd_constant_expr",
"ppd_constant_where",
"ppd_gby",
"ppd_gby2",
@@ -491,6 +526,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest {
"ppd_outer_join5",
"ppd_random",
"ppd_repeated_alias",
+ "ppd_transform",
"ppd_udf_col",
"ppd_union",
"ppr_allchildsarenull",
@@ -503,7 +539,15 @@ class HiveCompatibilitySuite extends HiveQueryFileTest {
"query_with_semi",
"quote1",
"quote2",
+ "rcfile_columnar",
+ "rcfile_lazydecompress",
+ "rcfile_null_value",
+ "rcfile_toleratecorruptions",
+ "rcfile_union",
+ "reduce_deduplicate",
+ "reduce_deduplicate_exclude_gby",
"reduce_deduplicate_exclude_join",
+ "reducesink_dedup",
"rename_column",
"router_join_ppr",
"select_as_omitted",
@@ -531,6 +575,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest {
"smb_mapjoin_3",
"smb_mapjoin_4",
"smb_mapjoin_5",
+ "smb_mapjoin_6",
+ "smb_mapjoin_7",
"smb_mapjoin_8",
"sort",
"sort_merge_join_desc_1",
@@ -541,21 +587,27 @@ class HiveCompatibilitySuite extends HiveQueryFileTest {
"sort_merge_join_desc_6",
"sort_merge_join_desc_7",
"stats0",
+ "stats_aggregator_error_1",
"stats_empty_partition",
+ "stats_publisher_error_1",
"subq2",
"tablename_with_select",
+ "timestamp_comparison",
"touch",
+ "transform_ppr1",
+ "transform_ppr2",
+ "type_cast_1",
"type_widening",
"udaf_collect_set",
"udaf_corr",
"udaf_covar_pop",
"udaf_covar_samp",
+ "udaf_histogram_numeric",
+ "udf_10_trims",
"udf2",
"udf6",
+ "udf8",
"udf9",
- "udf_10_trims",
- "udf_E",
- "udf_PI",
"udf_abs",
"udf_acos",
"udf_add",
@@ -585,13 +637,14 @@ class HiveCompatibilitySuite extends HiveQueryFileTest {
"udf_cos",
"udf_count",
"udf_date_add",
- "udf_date_sub",
"udf_datediff",
+ "udf_date_sub",
"udf_day",
"udf_dayofmonth",
"udf_degrees",
"udf_div",
"udf_double",
+ "udf_E",
"udf_exp",
"udf_field",
"udf_find_in_set",
@@ -631,6 +684,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest {
"udf_nvl",
"udf_or",
"udf_parse_url",
+ "udf_PI",
"udf_positive",
"udf_pow",
"udf_power",
@@ -671,9 +725,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest {
"udf_trim",
"udf_ucase",
"udf_upper",
+ "udf_variance",
"udf_var_pop",
"udf_var_samp",
- "udf_variance",
"udf_weekofyear",
"udf_when",
"udf_xpath",
@@ -703,8 +757,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest {
"union27",
"union28",
"union29",
+ "union3",
"union30",
"union31",
+ "union33",
"union34",
"union4",
"union5",
@@ -714,6 +770,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest {
"union9",
"union_lateralview",
"union_ppr",
+ "union_remove_11",
"union_remove_3",
"union_remove_6",
"union_script",
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 4b92d167a1263..a09667ac84b01 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -23,6 +23,16 @@ import org.apache.spark.sql.hive.TestHive._
* A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution.
*/
class HiveQuerySuite extends HiveComparisonTest {
+
+ test("Query expressed in SQL") {
+ assert(sql("SELECT 1").collect() === Array(Seq(1)))
+ }
+
+ test("Query expressed in HiveQL") {
+ hql("FROM src SELECT key").collect()
+ hiveql("FROM src SELECT key").collect()
+ }
+
createQueryTest("Simple Average",
"SELECT AVG(key) FROM src")
@@ -53,10 +63,8 @@ class HiveQuerySuite extends HiveComparisonTest {
createQueryTest("length.udf",
"SELECT length(\"test\") FROM src LIMIT 1")
- ignore("partitioned table scan") {
- createQueryTest("partitioned table scan",
- "SELECT ds, hr, key, value FROM srcpart")
- }
+ createQueryTest("partitioned table scan",
+ "SELECT ds, hr, key, value FROM srcpart")
createQueryTest("hash",
"SELECT hash('test') FROM src LIMIT 1")
@@ -135,7 +143,11 @@ class HiveQuerySuite extends HiveComparisonTest {
"SELECT * FROM src LATERAL VIEW explode(map(key+3,key+4)) D as k, v")
test("sampling") {
- sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s")
+ hql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s")
}
+ test("SchemaRDD toString") {
+ hql("SHOW TABLES").toString
+ hql("SELECT * FROM src").toString
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
index d77900ddc950c..8883e5b16d4da 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
@@ -48,7 +48,7 @@ class HiveResolutionSuite extends HiveComparisonTest {
createQueryTest("attr",
"SELECT key FROM src a ORDER BY key LIMIT 1")
- createQueryTest("alias.*",
+ createQueryTest("alias.star",
"SELECT a.* FROM src a ORDER BY key LIMIT 1")
test("case insensitivity with scala reflection") {
@@ -56,7 +56,7 @@ class HiveResolutionSuite extends HiveComparisonTest {
TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2)) :: Nil)
.registerAsTable("caseSensitivityTest")
- sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest")
+ hql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest")
}
/**
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
index 1318ac1968dad..d9ccb93e23923 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
@@ -136,7 +136,7 @@ class PruningSuite extends HiveComparisonTest {
expectedScannedColumns: Seq[String],
expectedPartValues: Seq[Seq[String]]) = {
test(s"$testCaseName - pruning test") {
- val plan = new TestHive.SqlQueryExecution(sql).executedPlan
+ val plan = new TestHive.HiveQLQueryExecution(sql).executedPlan
val actualOutputColumns = plan.output.map(_.name)
val (actualScannedColumns, actualPartValues) = plan.collect {
case p @ HiveTableScan(columns, relation, _) =>
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala
index 05ad85b622ac8..aade62eb8f84e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala
@@ -17,147 +17,138 @@
package org.apache.spark.sql.parquet
-import java.io.File
-
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
-import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
-import org.apache.spark.sql.catalyst.expressions.Row
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.util.getTempFilePath
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row}
+import org.apache.spark.sql.catalyst.types.{DataType, StringType, IntegerType}
+import org.apache.spark.sql.{parquet, SchemaRDD}
import org.apache.spark.sql.hive.TestHive
+import org.apache.spark.util.Utils
+
+// Implicits
+import org.apache.spark.sql.hive.TestHive._
class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach {
- val filename = getTempFilePath("parquettest").getCanonicalFile.toURI.toString
-
- // runs a SQL and optionally resolves one Parquet table
- def runQuery(
- querystr: String,
- tableName: Option[String] = None,
- filename: Option[String] = None): Array[Row] = {
-
- // call to resolve references in order to get CREATE TABLE AS to work
- val query = TestHive
- .parseSql(querystr)
- val finalQuery =
- if (tableName.nonEmpty && filename.nonEmpty)
- resolveParquetTable(tableName.get, filename.get, query)
- else
- query
- TestHive.executePlan(finalQuery)
- .toRdd
- .collect()
- }
- // stores a query output to a Parquet file
- def storeQuery(querystr: String, filename: String): Unit = {
- val query = WriteToFile(
- filename,
- TestHive.parseSql(querystr))
- TestHive
- .executePlan(query)
- .stringResult()
- }
+ val dirname = Utils.createTempDir()
- /**
- * TODO: This function is necessary as long as there is no notion of a Catalog for
- * Parquet tables. Once such a thing exists this functionality should be moved there.
- */
- def resolveParquetTable(tableName: String, filename: String, plan: LogicalPlan): LogicalPlan = {
- TestHive.loadTestTable("src") // may not be loaded now
- plan.transform {
- case relation @ UnresolvedRelation(databaseName, name, alias) =>
- if (name == tableName)
- ParquetRelation(tableName, filename)
- else
- relation
- case op @ InsertIntoCreatedTable(databaseName, name, child) =>
- if (name == tableName) {
- // note: at this stage the plan is not yet analyzed but Parquet needs to know the schema
- // and for that we need the child to be resolved
- val relation = ParquetRelation.create(
- filename,
- TestHive.analyzer(child),
- TestHive.sparkContext.hadoopConfiguration,
- Some(tableName))
- InsertIntoTable(
- relation.asInstanceOf[BaseRelation],
- Map.empty,
- child,
- overwrite = false)
- } else
- op
- }
- }
+ var testRDD: SchemaRDD = null
override def beforeAll() {
// write test data
- ParquetTestData.writeFile()
- // Override initial Parquet test table
- TestHive.catalog.registerTable(Some[String]("parquet"), "testsource", ParquetTestData.testData)
+ ParquetTestData.writeFile
+ testRDD = parquetFile(ParquetTestData.testDir.toString)
+ testRDD.registerAsTable("testsource")
}
override def afterAll() {
- ParquetTestData.testFile.delete()
+ Utils.deleteRecursively(ParquetTestData.testDir)
+ Utils.deleteRecursively(dirname)
+ reset() // drop all tables that were registered as part of the tests
}
+ // in case tests are failing we delete before and after each test
override def beforeEach() {
- new File(filename).getAbsoluteFile.delete()
+ Utils.deleteRecursively(dirname)
}
override def afterEach() {
- new File(filename).getAbsoluteFile.delete()
+ Utils.deleteRecursively(dirname)
}
test("SELECT on Parquet table") {
- val rdd = runQuery("SELECT * FROM parquet.testsource")
+ val rdd = hql("SELECT * FROM testsource").collect()
assert(rdd != null)
assert(rdd.forall(_.size == 6))
}
test("Simple column projection + filter on Parquet table") {
- val rdd = runQuery("SELECT myboolean, mylong FROM parquet.testsource WHERE myboolean=true")
+ val rdd = hql("SELECT myboolean, mylong FROM testsource WHERE myboolean=true").collect()
assert(rdd.size === 5, "Filter returned incorrect number of rows")
assert(rdd.forall(_.getBoolean(0)), "Filter returned incorrect Boolean field value")
}
- test("Converting Hive to Parquet Table via WriteToFile") {
- storeQuery("SELECT * FROM src", filename)
- val rddOne = runQuery("SELECT * FROM src").sortBy(_.getInt(0))
- val rddTwo = runQuery("SELECT * from ptable", Some("ptable"), Some(filename)).sortBy(_.getInt(0))
+ test("Converting Hive to Parquet Table via saveAsParquetFile") {
+ hql("SELECT * FROM src").saveAsParquetFile(dirname.getAbsolutePath)
+ parquetFile(dirname.getAbsolutePath).registerAsTable("ptable")
+ val rddOne = hql("SELECT * FROM src").collect().sortBy(_.getInt(0))
+ val rddTwo = hql("SELECT * from ptable").collect().sortBy(_.getInt(0))
compareRDDs(rddOne, rddTwo, "src (Hive)", Seq("key:Int", "value:String"))
}
test("INSERT OVERWRITE TABLE Parquet table") {
- storeQuery("SELECT * FROM parquet.testsource", filename)
- runQuery("INSERT OVERWRITE TABLE ptable SELECT * FROM parquet.testsource", Some("ptable"), Some(filename))
- runQuery("INSERT OVERWRITE TABLE ptable SELECT * FROM parquet.testsource", Some("ptable"), Some(filename))
- val rddCopy = runQuery("SELECT * FROM ptable", Some("ptable"), Some(filename))
- val rddOrig = runQuery("SELECT * FROM parquet.testsource")
- compareRDDs(rddOrig, rddCopy, "parquet.testsource", ParquetTestData.testSchemaFieldNames)
+ hql("SELECT * FROM testsource").saveAsParquetFile(dirname.getAbsolutePath)
+ parquetFile(dirname.getAbsolutePath).registerAsTable("ptable")
+ // let's do three overwrites for good measure
+ hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
+ hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
+ hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect()
+ val rddCopy = hql("SELECT * FROM ptable").collect()
+ val rddOrig = hql("SELECT * FROM testsource").collect()
+ assert(rddCopy.size === rddOrig.size, "INSERT OVERWRITE changed size of table??")
+ compareRDDs(rddOrig, rddCopy, "testsource", ParquetTestData.testSchemaFieldNames)
}
- test("CREATE TABLE AS Parquet table") {
- runQuery("CREATE TABLE ptable AS SELECT * FROM src", Some("ptable"), Some(filename))
- val rddCopy = runQuery("SELECT * FROM ptable", Some("ptable"), Some(filename))
+ test("CREATE TABLE of Parquet table") {
+ createParquetFile(dirname.getAbsolutePath, ("key", IntegerType), ("value", StringType))
+ .registerAsTable("tmp")
+ val rddCopy =
+ hql("INSERT INTO TABLE tmp SELECT * FROM src")
+ .collect()
.sortBy[Int](_.apply(0) match {
case x: Int => x
case _ => 0
})
- val rddOrig = runQuery("SELECT * FROM src").sortBy(_.getInt(0))
+ val rddOrig = hql("SELECT * FROM src")
+ .collect()
+ .sortBy(_.getInt(0))
compareRDDs(rddOrig, rddCopy, "src (Hive)", Seq("key:Int", "value:String"))
}
+ test("Appending to Parquet table") {
+ createParquetFile(dirname.getAbsolutePath, ("key", IntegerType), ("value", StringType))
+ .registerAsTable("tmpnew")
+ hql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
+ hql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
+ hql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect()
+ val rddCopies = hql("SELECT * FROM tmpnew").collect()
+ val rddOrig = hql("SELECT * FROM src").collect()
+ assert(rddCopies.size === 3 * rddOrig.size, "number of copied rows via INSERT INTO did not match correct number")
+ }
+
+ test("Appending to and then overwriting Parquet table") {
+ createParquetFile(dirname.getAbsolutePath, ("key", IntegerType), ("value", StringType))
+ .registerAsTable("tmp")
+ hql("INSERT INTO TABLE tmp SELECT * FROM src").collect()
+ hql("INSERT INTO TABLE tmp SELECT * FROM src").collect()
+ hql("INSERT OVERWRITE TABLE tmp SELECT * FROM src").collect()
+ val rddCopies = hql("SELECT * FROM tmp").collect()
+ val rddOrig = hql("SELECT * FROM src").collect()
+ assert(rddCopies.size === rddOrig.size, "INSERT OVERWRITE did not actually overwrite")
+ }
+
private def compareRDDs(rddOne: Array[Row], rddTwo: Array[Row], tableName: String, fieldNames: Seq[String]) {
var counter = 0
(rddOne, rddTwo).zipped.foreach {
(a,b) => (a,b).zipped.toArray.zipWithIndex.foreach {
- case ((value_1:Array[Byte], value_2:Array[Byte]), index) =>
- assert(new String(value_1) === new String(value_2), s"table $tableName row $counter field ${fieldNames(index)} don't match")
case ((value_1, value_2), index) =>
assert(value_1 === value_2, s"table $tableName row $counter field ${fieldNames(index)} don't match")
}
counter = counter + 1
}
}
+
+ /**
+ * Creates an empty SchemaRDD backed by a ParquetRelation.
+ *
+ * TODO: since this is so experimental it is better to have it here and not
+ * in SQLContext. Also note that when creating new AttributeReferences
+ * one needs to take care not to create duplicate Attribute ID's.
+ */
+ private def createParquetFile(path: String, schema: (Tuple2[String, DataType])*): SchemaRDD = {
+ val attributes = schema.map(t => new AttributeReference(t._1, t._2)())
+ new SchemaRDD(
+ TestHive,
+ parquet.ParquetRelation.createEmpty(path, attributes, sparkContext.hadoopConfiguration))
+ }
}
diff --git a/streaming/pom.xml b/streaming/pom.xml
index 1953cc6883378..93b1c5a37aff9 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -96,7 +96,6 @@
org.apache.maven.plugins
maven-jar-plugin
- 2.2
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
index fde46705d89fb..d3339063cc079 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
@@ -153,7 +153,7 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
def validate() {
this.synchronized {
assert(batchDuration != null, "Batch duration has not been set")
- //assert(batchDuration >= Milliseconds(100), "Batch duration of " + batchDuration +
+ // assert(batchDuration >= Milliseconds(100), "Batch duration of " + batchDuration +
// " is very low")
assert(getOutputStreams().size > 0, "No output streams registered, so nothing to execute")
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 72d566f3cb0a5..a4e236c65ff86 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -531,7 +531,7 @@ object StreamingContext extends Logging {
* Find the JAR from which a given class was loaded, to make it easy for users to pass
* their JARs to StreamingContext.
*/
- def jarOfClass(cls: Class[_]) = SparkContext.jarOfClass(cls)
+ def jarOfClass(cls: Class[_]): Seq[String] = SparkContext.jarOfClass(cls)
private[streaming] def createNewSparkContext(conf: SparkConf): SparkContext = {
// Set the default cleaner delay to an hour if not already set.
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
index a85cd04c9319c..bb2f492d06a00 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
@@ -49,7 +49,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
* Print the first ten elements of each RDD generated in this DStream. This is an output
* operator, so this DStream will be registered as an output stream and there materialized.
*/
- def print() = dstream.print()
+ def print(): Unit = {
+ dstream.print()
+ }
/**
* Return a new DStream in which each RDD has a single element generated by counting each RDD
@@ -401,7 +403,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
* Enable periodic checkpointing of RDDs of this DStream.
* @param interval Time interval after which generated RDD will be checkpointed
*/
- def checkpoint(interval: Duration) = {
+ def checkpoint(interval: Duration): DStream[T] = {
dstream.checkpoint(interval)
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index a3463657ef0b7..c800602d0959b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -477,25 +477,33 @@ class JavaStreamingContext(val ssc: StreamingContext) {
/**
* Start the execution of the streams.
*/
- def start() = ssc.start()
+ def start(): Unit = {
+ ssc.start()
+ }
/**
* Wait for the execution to stop. Any exceptions that occurs during the execution
* will be thrown in this thread.
*/
- def awaitTermination() = ssc.awaitTermination()
+ def awaitTermination(): Unit = {
+ ssc.awaitTermination()
+ }
/**
* Wait for the execution to stop. Any exceptions that occurs during the execution
* will be thrown in this thread.
* @param timeout time to wait in milliseconds
*/
- def awaitTermination(timeout: Long) = ssc.awaitTermination(timeout)
+ def awaitTermination(timeout: Long): Unit = {
+ ssc.awaitTermination(timeout)
+ }
/**
* Stop the execution of the streams. Will stop the associated JavaSparkContext as well.
*/
- def stop() = ssc.stop()
+ def stop(): Unit = {
+ ssc.stop()
+ }
/**
* Stop the execution of the streams.
@@ -589,7 +597,7 @@ object JavaStreamingContext {
* Find the JAR from which a given class was loaded, to make it easy for users to pass
* their JARs to StreamingContext.
*/
- def jarOfClass(cls: Class[_]) = SparkContext.jarOfClass(cls).toArray
+ def jarOfClass(cls: Class[_]): Array[String] = SparkContext.jarOfClass(cls).toArray
}
/**
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
index 6bff56a9d332a..d48b51aa69565 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -503,14 +503,18 @@ abstract class DStream[T: ClassTag] (
* 'this' DStream will be registered as an output stream and therefore materialized.
*/
@deprecated("use foreachRDD", "0.9.0")
- def foreach(foreachFunc: RDD[T] => Unit) = this.foreachRDD(foreachFunc)
+ def foreach(foreachFunc: RDD[T] => Unit): Unit = {
+ this.foreachRDD(foreachFunc)
+ }
/**
* Apply a function to each RDD in this DStream. This is an output operator, so
* 'this' DStream will be registered as an output stream and therefore materialized.
*/
@deprecated("use foreachRDD", "0.9.0")
- def foreach(foreachFunc: (RDD[T], Time) => Unit) = this.foreachRDD(foreachFunc)
+ def foreach(foreachFunc: (RDD[T], Time) => Unit): Unit = {
+ this.foreachRDD(foreachFunc)
+ }
/**
* Apply a function to each RDD in this DStream. This is an output operator, so
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala
index ca0a8ae47864d..b334d68bf9910 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala
@@ -78,7 +78,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag](
override def checkpoint(interval: Duration): DStream[(K, V)] = {
super.checkpoint(interval)
- //reducedStream.checkpoint(interval)
+ // reducedStream.checkpoint(interval)
this
}
@@ -128,7 +128,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag](
// Cogroup the reduced RDDs and merge the reduced values
val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(K, _)]]],
partitioner)
- //val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _
+ // val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _
val numOldValues = oldRDDs.size
val numNewValues = newRDDs.size
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
index 9d8889b655356..5f7d3ba26c656 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
@@ -64,7 +64,6 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
}
val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
- //logDebug("Generating state RDD for time " + validTime)
Some(stateRDD)
}
case None => { // If parent RDD does not exist
@@ -97,11 +96,11 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
val groupedRDD = parentRDD.groupByKey(partitioner)
val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning)
- //logDebug("Generating state RDD for time " + validTime + " (first)")
+ // logDebug("Generating state RDD for time " + validTime + " (first)")
Some(sessionRDD)
}
case None => { // If parent RDD does not exist, then nothing to do!
- //logDebug("Not generating state RDD (no previous state, no parent)")
+ // logDebug("Not generating state RDD (no previous state, no parent)")
None
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala
index 4e8d07fe921fb..7f3cd2f8eb1fd 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala
@@ -39,17 +39,19 @@ case class BatchInfo(
* was submitted to the streaming scheduler. Essentially, it is
* `processingStartTime` - `submissionTime`.
*/
- def schedulingDelay = processingStartTime.map(_ - submissionTime)
+ def schedulingDelay: Option[Long] = processingStartTime.map(_ - submissionTime)
/**
* Time taken for the all jobs of this batch to finish processing from the time they started
* processing. Essentially, it is `processingEndTime` - `processingStartTime`.
*/
- def processingDelay = processingEndTime.zip(processingStartTime).map(x => x._1 - x._2).headOption
+ def processingDelay: Option[Long] = processingEndTime.zip(processingStartTime)
+ .map(x => x._1 - x._2).headOption
/**
* Time taken for all the jobs of this batch to finish processing from the time they
* were submitted. Essentially, it is `processingDelay` + `schedulingDelay`.
*/
- def totalDelay = schedulingDelay.zip(processingDelay).map(x => x._1 + x._2).headOption
+ def totalDelay: Option[Long] = schedulingDelay.zip(processingDelay)
+ .map(x => x._1 + x._2).headOption
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index 0784e562ac719..25739956cb889 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -252,7 +252,7 @@ class CheckpointSuite extends TestSuiteBase {
ssc.start()
// Create files and advance manual clock to process them
- //var clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ // var clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
Thread.sleep(1000)
for (i <- Seq(1, 2, 3)) {
Files.write(i + "\n", new File(testDir, i.toString), Charset.forName("UTF-8"))
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index 74e73ebb342fe..389b23d4d5e4b 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -144,8 +144,8 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
}
-
- test("actor input stream") {
+ // TODO: This test makes assumptions about Thread.sleep() and is flaky
+ ignore("actor input stream") {
// Start the server
val testServer = new TestServer()
val port = testServer.port
@@ -154,7 +154,8 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
// Set up the streaming context and input streams
val ssc = new StreamingContext(conf, batchDuration)
val networkStream = ssc.actorStream[String](Props(new TestActor(port)), "TestActor",
- StorageLevel.MEMORY_AND_DISK) //Had to pass the local value of port to prevent from closing over entire scope
+ // Had to pass the local value of port to prevent from closing over entire scope
+ StorageLevel.MEMORY_AND_DISK)
val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
val outputStream = new TestOutputStream(networkStream, outputBuffer)
def output = outputBuffer.flatMap(x => x)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala
index 7d18a0fcf7ba8..9ebf7b484f421 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala
@@ -36,8 +36,9 @@ class RateLimitedOutputStreamSuite extends FunSuite {
val stream = new RateLimitedOutputStream(underlying, desiredBytesPerSec = 10000)
val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) }
- // We accept anywhere from 4.0 to 4.99999 seconds since the value is rounded down.
- assert(SECONDS.convert(elapsedNs, NANOSECONDS) === 4)
+ val seconds = SECONDS.convert(elapsedNs, NANOSECONDS)
+ assert(seconds >= 4, s"Seconds value ($seconds) is less than 4.")
+ assert(seconds <= 30, s"Took more than 30 seconds ($seconds) to write data.")
assert(underlying.toString("UTF-8") === data)
}
}
diff --git a/tools/pom.xml b/tools/pom.xml
index 11433e596f5b0..ae2ba64e07c21 100644
--- a/tools/pom.xml
+++ b/tools/pom.xml
@@ -55,6 +55,14 @@
spark-streaming_${scala.binary.version}
${project.version}
+
+ org.scala-lang
+ scala-reflect
+
+
+ org.scala-lang
+ scala-compiler
+
org.scalatest
scalatest_${scala.binary.version}
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 71a64ecf5879a..0179b0600c61f 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -167,6 +167,9 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa
object Client {
def main(argStrings: Array[String]) {
+ println("WARNING: This client is deprecated and will be removed in a future version of Spark.")
+ println("Use ./bin/spark-submit with \"--master yarn\"")
+
// Set an env variable indicating we are running in YARN mode.
// Note that anything with SPARK prefix gets propagated to all (remote) processes
System.setProperty("SPARK_YARN_MODE", "true")
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
index 981e8b05f602d..3469b7decedf6 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
@@ -81,7 +81,8 @@ class ExecutorRunnable(
credentials.writeTokenStorageToStream(dob)
ctx.setContainerTokens(ByteBuffer.wrap(dob.getData()))
- val commands = prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores)
+ val commands = prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores,
+ localResources.contains(ClientBase.LOG4J_PROP))
logInfo("Setting up executor with commands: " + commands)
ctx.setCommands(commands)
diff --git a/yarn/common/src/main/resources/log4j-spark-container.properties b/yarn/common/src/main/resources/log4j-spark-container.properties
new file mode 100644
index 0000000000000..a1e37a0be27dd
--- /dev/null
+++ b/yarn/common/src/main/resources/log4j-spark-container.properties
@@ -0,0 +1,24 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License. See accompanying LICENSE file.
+
+# Set everything to be logged to the console
+log4j.rootCategory=INFO, console
+log4j.appender.console=org.apache.log4j.ConsoleAppender
+log4j.appender.console.target=System.err
+log4j.appender.console.layout=org.apache.log4j.PatternLayout
+log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
+
+# Settings to quiet third party logs that are too verbose
+log4j.logger.org.eclipse.jetty=WARN
+log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
+log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
index c565f2dde24fc..3e4c739e34fe9 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
@@ -63,7 +63,10 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) {
userClass = value
args = tail
- case ("--args") :: value :: tail =>
+ case ("--args" | "--arg") :: value :: tail =>
+ if (args(0) == "--args") {
+ println("--args is deprecated. Use --arg instead.")
+ }
userArgsBuffer += value
args = tail
@@ -146,8 +149,8 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) {
"Options:\n" +
" --jar JAR_PATH Path to your application's JAR file (required in yarn-cluster mode)\n" +
" --class CLASS_NAME Name of your application's main class (required)\n" +
- " --args ARGS Arguments to be passed to your application's main class.\n" +
- " Mutliple invocations are possible, each will be passed in order.\n" +
+ " --arg ARGS Argument to be passed to your application's main class.\n" +
+ " Multiple invocations are possible, each will be passed in order.\n" +
" --num-executors NUM Number of executors to start (Default: 2)\n" +
" --executor-cores NUM Number of cores for the executors (Default: 1).\n" +
" --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: 512 Mb)\n" +
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
index 57e5761cba896..eb42922aea228 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
@@ -139,7 +139,6 @@ trait ClientBase extends Logging {
} else if (srcHost != null && dstHost == null) {
return false
}
- //check for ports
if (srcUri.getPort() != dstUri.getPort()) {
false
} else {
@@ -267,11 +266,11 @@ trait ClientBase extends Logging {
localResources: HashMap[String, LocalResource],
stagingDir: String): HashMap[String, String] = {
logInfo("Setting up the launch environment")
- val log4jConfLocalRes = localResources.getOrElse(ClientBase.LOG4J_PROP, null)
val env = new HashMap[String, String]()
- ClientBase.populateClasspath(yarnConf, sparkConf, log4jConfLocalRes != null, env)
+ ClientBase.populateClasspath(yarnConf, sparkConf, localResources.contains(ClientBase.LOG4J_PROP),
+ env)
env("SPARK_YARN_MODE") = "true"
env("SPARK_YARN_STAGING_DIR") = stagingDir
env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName()
@@ -345,15 +344,13 @@ trait ClientBase extends Logging {
JAVA_OPTS += " " + env("SPARK_JAVA_OPTS")
}
- // Command for the ApplicationMaster
- var javaCommand = "java"
- val javaHome = System.getenv("JAVA_HOME")
- if ((javaHome != null && !javaHome.isEmpty()) || env.isDefinedAt("JAVA_HOME")) {
- javaCommand = Environment.JAVA_HOME.$() + "/bin/java"
+ if (!localResources.contains(ClientBase.LOG4J_PROP)) {
+ JAVA_OPTS += " " + YarnSparkHadoopUtil.getLoggingArgsForContainerCommandLine()
}
+ // Command for the ApplicationMaster
val commands = List[String](
- javaCommand +
+ Environment.JAVA_HOME.$() + "/bin/java" +
" -server " +
JAVA_OPTS +
" " + args.amClass +
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
index 68cda0f1c9f8b..9b7f1fca96c6d 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
@@ -157,7 +157,7 @@ class ClientDistributedCacheManager() extends Logging {
def isPublic(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): Boolean = {
val fs = FileSystem.get(uri, conf)
val current = new Path(uri.getPath())
- //the leaf level file should be readable by others
+ // the leaf level file should be readable by others
if (!checkPermissionOfOther(fs, current, FsAction.READ, statCache)) {
return false
}
@@ -177,7 +177,7 @@ class ClientDistributedCacheManager() extends Logging {
statCache: Map[URI, FileStatus]): Boolean = {
var current = path
while (current != null) {
- //the subdirs in the path should have execute permissions for others
+ // the subdirs in the path should have execute permissions for others
if (!checkPermissionOfOther(fs, current, FsAction.EXECUTE, statCache)) {
return false
}
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala
index da0a6f74efcd5..b3696c5fe7183 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala
@@ -50,7 +50,8 @@ trait ExecutorRunnableUtil extends Logging {
slaveId: String,
hostname: String,
executorMemory: Int,
- executorCores: Int) = {
+ executorCores: Int,
+ userSpecifiedLogFile: Boolean) = {
// Extra options for the JVM
var JAVA_OPTS = ""
// Set the JVM memory
@@ -63,6 +64,10 @@ trait ExecutorRunnableUtil extends Logging {
JAVA_OPTS += " -Djava.io.tmpdir=" +
new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) + " "
+ if (!userSpecifiedLogFile) {
+ JAVA_OPTS += " " + YarnSparkHadoopUtil.getLoggingArgsForContainerCommandLine()
+ }
+
// Commenting it out for now - so that people can refer to the properties if required. Remove
// it once cpuset version is pushed out.
// The context is, default gc for server class machines end up using all cores to do gc - hence
@@ -88,13 +93,8 @@ trait ExecutorRunnableUtil extends Logging {
}
*/
- var javaCommand = "java"
- val javaHome = System.getenv("JAVA_HOME")
- if ((javaHome != null && !javaHome.isEmpty()) || env.isDefinedAt("JAVA_HOME")) {
- javaCommand = Environment.JAVA_HOME.$() + "/bin/java"
- }
-
- val commands = List[String](javaCommand +
+ val commands = List[String](
+ Environment.JAVA_HOME.$() + "/bin/java" +
" -server " +
// Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling.
// Not killing the task leaves various aspects of the executor and (to some extent) the jvm in
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
index 4c6e1dcd6dac3..314a7550ada71 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
@@ -22,6 +22,7 @@ import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.security.Credentials
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.api.ApplicationConstants
import org.apache.hadoop.conf.Configuration
import org.apache.spark.deploy.SparkHadoopUtil
@@ -67,3 +68,9 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil {
}
}
+
+object YarnSparkHadoopUtil {
+ def getLoggingArgsForContainerCommandLine(): String = {
+ "-Dlog4j.configuration=log4j-spark-container.properties"
+ }
+}
diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index d1f13e3c369ed..161918859e7c4 100644
--- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -33,11 +33,12 @@ private[spark] class YarnClientSchedulerBackend(
var client: Client = null
var appId: ApplicationId = null
- private[spark] def addArg(optionName: String, optionalParam: String, arrayBuf: ArrayBuffer[String]) {
- Option(System.getenv(optionalParam)) foreach {
- optParam => {
- arrayBuf += (optionName, optParam)
- }
+ private[spark] def addArg(optionName: String, envVar: String, sysProp: String,
+ arrayBuf: ArrayBuffer[String]) {
+ if (System.getProperty(sysProp) != null) {
+ arrayBuf += (optionName, System.getProperty(sysProp))
+ } else if (System.getenv(envVar) != null) {
+ arrayBuf += (optionName, System.getenv(envVar))
}
}
@@ -56,22 +57,24 @@ private[spark] class YarnClientSchedulerBackend(
"--am-class", "org.apache.spark.deploy.yarn.ExecutorLauncher"
)
- // process any optional arguments, use the defaults already defined in ClientArguments
- // if things aren't specified
- Map("SPARK_MASTER_MEMORY" -> "--driver-memory",
- "SPARK_DRIVER_MEMORY" -> "--driver-memory",
- "SPARK_WORKER_INSTANCES" -> "--num-executors",
- "SPARK_WORKER_MEMORY" -> "--executor-memory",
- "SPARK_WORKER_CORES" -> "--executor-cores",
- "SPARK_EXECUTOR_INSTANCES" -> "--num-executors",
- "SPARK_EXECUTOR_MEMORY" -> "--executor-memory",
- "SPARK_EXECUTOR_CORES" -> "--executor-cores",
- "SPARK_YARN_QUEUE" -> "--queue",
- "SPARK_YARN_APP_NAME" -> "--name",
- "SPARK_YARN_DIST_FILES" -> "--files",
- "SPARK_YARN_DIST_ARCHIVES" -> "--archives")
- .foreach { case (optParam, optName) => addArg(optName, optParam, argsArrayBuf) }
-
+ // process any optional arguments, given either as environment variables
+ // or system properties. use the defaults already defined in ClientArguments
+ // if things aren't specified. system properties override environment
+ // variables.
+ List(("--driver-memory", "SPARK_MASTER_MEMORY", "spark.master.memory"),
+ ("--driver-memory", "SPARK_DRIVER_MEMORY", "spark.driver.memory"),
+ ("--num-executors", "SPARK_WORKER_INSTANCES", "spark.worker.instances"),
+ ("--num-executors", "SPARK_EXECUTOR_INSTANCES", "spark.executor.instances"),
+ ("--executor-memory", "SPARK_WORKER_MEMORY", "spark.executor.memory"),
+ ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"),
+ ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"),
+ ("--executor-cores", "SPARK_EXECUTOR_CORES", "spark.executor.cores"),
+ ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue"),
+ ("--name", "SPARK_YARN_APP_NAME", "spark.app.name"),
+ ("--files", "SPARK_YARN_DIST_FILES", "spark.yarn.dist.files"),
+ ("--archives", "SPARK_YARN_DIST_ARCHIVES", "spark.yarn.dist.archives"))
+ .foreach { case (optName, envVar, sysProp) => addArg(optName, envVar, sysProp, argsArrayBuf) }
+
logDebug("ClientArguments called with: " + argsArrayBuf)
val args = new ClientArguments(argsArrayBuf.toArray, conf)
client = new Client(args, conf)
diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
index 458df4fa3cd99..80b57d1355a3a 100644
--- a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
+++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
@@ -99,7 +99,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar {
assert(env.get("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === None)
assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None)
- //add another one and verify both there and order correct
+ // add another one and verify both there and order correct
val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
null, new Path("/tmp/testing2"))
val destPath2 = new Path("file:///foo.invalid.com:8080/tmp/testing2")
diff --git a/yarn/pom.xml b/yarn/pom.xml
index 35e31760c1f02..3342cb65edcd1 100644
--- a/yarn/pom.xml
+++ b/yarn/pom.xml
@@ -167,6 +167,12 @@
target/scala-${scala.binary.version}/classes
target/scala-${scala.binary.version}/test-classes
+
+
+
+ ../common/src/main/resources
+
+
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 837b7e12cb0de..77eb1276a0c4e 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -173,6 +173,9 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa
object Client {
def main(argStrings: Array[String]) {
+ println("WARNING: This client is deprecated and will be removed in a future version of Spark.")
+ println("Use ./bin/spark-submit with \"--master yarn\"")
+
// Set an env variable indicating we are running in YARN mode.
// Note: anything env variable with SPARK_ prefix gets propagated to all (remote) processes -
// see Client#setupLaunchEnv().
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
index 53c403f7d0913..81d9d1b5c9280 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
@@ -78,7 +78,8 @@ class ExecutorRunnable(
credentials.writeTokenStorageToStream(dob)
ctx.setTokens(ByteBuffer.wrap(dob.getData()))
- val commands = prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores)
+ val commands = prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores,
+ localResources.contains(ClientBase.LOG4J_PROP))
logInfo("Setting up executor with commands: " + commands)
ctx.setCommands(commands)