diff --git a/README.md b/README.md index 6211a5889a3f5..f87e07aa5cc90 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,13 @@ # Apache Spark -Lightning-Fast Cluster Computing - +Spark is a fast and general cluster computing system for Big Data. It provides +high-level APIs in Scala, Java, and Python, and an optimized engine that +supports general computation graphs for data analysis. It also supports a +rich set of higher-level tools including Spark SQL for SQL and structured +data processing, MLLib for machine learning, GraphX for graph processing, +and Spark Streaming. + + ## Online Documentation @@ -69,29 +76,28 @@ can be run using: Spark uses the Hadoop core library to talk to HDFS and other Hadoop-supported storage systems. Because the protocols have changed in different versions of Hadoop, you must build Spark against the same version that your cluster runs. -You can change the version by setting the `SPARK_HADOOP_VERSION` environment -when building Spark. +You can change the version by setting `-Dhadoop.version` when building Spark. For Apache Hadoop versions 1.x, Cloudera CDH MRv1, and other Hadoop versions without YARN, use: # Apache Hadoop 1.2.1 - $ SPARK_HADOOP_VERSION=1.2.1 sbt/sbt assembly + $ sbt/sbt -Dhadoop.version=1.2.1 assembly # Cloudera CDH 4.2.0 with MapReduce v1 - $ SPARK_HADOOP_VERSION=2.0.0-mr1-cdh4.2.0 sbt/sbt assembly + $ sbt/sbt -Dhadoop.version=2.0.0-mr1-cdh4.2.0 assembly For Apache Hadoop 2.2.X, 2.1.X, 2.0.X, 0.23.x, Cloudera CDH MRv2, and other Hadoop versions -with YARN, also set `SPARK_YARN=true`: +with YARN, also set `-Pyarn`: # Apache Hadoop 2.0.5-alpha - $ SPARK_HADOOP_VERSION=2.0.5-alpha SPARK_YARN=true sbt/sbt assembly + $ sbt/sbt -Dhadoop.version=2.0.5-alpha -Pyarn assembly # Cloudera CDH 4.2.0 with MapReduce v2 - $ SPARK_HADOOP_VERSION=2.0.0-cdh4.2.0 SPARK_YARN=true sbt/sbt assembly + $ sbt/sbt -Dhadoop.version=2.0.0-cdh4.2.0 -Pyarn assembly # Apache Hadoop 2.2.X and newer - $ SPARK_HADOOP_VERSION=2.2.0 SPARK_YARN=true sbt/sbt assembly + $ sbt/sbt -Dhadoop.version=2.2.0 -Pyarn assembly When developing a Spark application, specify the Hadoop version by adding the "hadoop-client" artifact to your project's dependencies. For example, if you're diff --git a/assembly/pom.xml b/assembly/pom.xml index 0c60b66c3daca..4f6aade133db7 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -32,6 +32,7 @@ pom + assembly scala-${scala.binary.version} spark-assembly-${project.version}-hadoop${hadoop.version}.jar ${project.build.directory}/${spark.jar.dir}/${spark.jar.basename} diff --git a/bagel/pom.xml b/bagel/pom.xml index c8e39a415af28..90c4b095bb611 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -27,6 +27,9 @@ org.apache.spark spark-bagel_2.10 + + bagel + jar Spark Project Bagel http://spark.apache.org/ diff --git a/bin/spark-class b/bin/spark-class index 04fa52c6756b1..3f6beca5becf0 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -110,9 +110,9 @@ export JAVA_OPTS TOOLS_DIR="$FWDIR"/tools SPARK_TOOLS_JAR="" -if [ -e "$TOOLS_DIR"/target/scala-$SCALA_VERSION/*assembly*[0-9Tg].jar ]; then +if [ -e "$TOOLS_DIR"/target/scala-$SCALA_VERSION/spark-tools*[0-9Tg].jar ]; then # Use the JAR from the SBT build - export SPARK_TOOLS_JAR=`ls "$TOOLS_DIR"/target/scala-$SCALA_VERSION/*assembly*[0-9Tg].jar` + export SPARK_TOOLS_JAR=`ls "$TOOLS_DIR"/target/scala-$SCALA_VERSION/spark-tools*[0-9Tg].jar` fi if [ -e "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar ]; then # Use the JAR from the Maven build diff --git a/core/pom.xml b/core/pom.xml index 6abf8480d5da0..1054cec4d77bb 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -27,6 +27,9 @@ org.apache.spark spark-core_2.10 + + core + jar Spark Project Core http://spark.apache.org/ @@ -111,6 +114,10 @@ org.xerial.snappy snappy-java + + net.jpountz.lz4 + lz4 + com.twitter chill_${scala.binary.version} diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 7448af87fcf38..445110d63e184 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -81,7 +81,9 @@ table.sortable thead { span.kill-link { margin-right: 2px; + margin-left: 20px; color: gray; + float: right; } span.kill-link a { diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index 59fdf659c9e11..1d640579efe77 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -56,8 +56,8 @@ case class Aggregator[K, V, C] ( } else { val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) while (iter.hasNext) { - val (k, v) = iter.next() - combiners.insert(k, v) + val pair = iter.next() + combiners.insert(pair._1, pair._2) } // TODO: Make this non optional in a future release Option(context).foreach(c => c.taskMetrics.memoryBytesSpilled = combiners.memoryBytesSpilled) @@ -85,8 +85,8 @@ case class Aggregator[K, V, C] ( } else { val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners) while (iter.hasNext) { - val (k, c) = iter.next() - combiners.insert(k, c) + val pair = iter.next() + combiners.insert(pair._1, pair._2) } // TODO: Make this non optional in a future release Option(context).foreach(c => c.taskMetrics.memoryBytesSpilled = combiners.memoryBytesSpilled) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 8819e73d17fb2..8052499ab7526 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1531,7 +1531,16 @@ object SparkContext extends Logging { throw new SparkException("YARN mode not available ?", e) } } - val backend = new CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + val backend = try { + val clazz = + Class.forName("org.apache.spark.scheduler.cluster.YarnClusterSchedulerBackend") + val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext]) + cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] + } catch { + case e: Exception => { + throw new SparkException("YARN mode not available ?", e) + } + } scheduler.initialize(backend) scheduler diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index df42d679b4699..8f0c5e78416c2 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -20,6 +20,7 @@ package org.apache.spark import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -88,10 +89,7 @@ case class ExceptionFailure( stackTrace: Array[StackTraceElement], metrics: Option[TaskMetrics]) extends TaskFailedReason { - override def toErrorString: String = { - val stackTraceString = if (stackTrace == null) "null" else stackTrace.mkString("\n") - s"$className ($description}\n$stackTraceString" - } + override def toErrorString: String = Utils.exceptionString(className, description, stackTrace) } /** diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 885c6829a2d72..8ca731038e528 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -92,8 +92,8 @@ private[spark] object TestUtils { def createCompiledClass(className: String, destDir: File, value: String = ""): File = { val compiler = ToolProvider.getSystemJavaCompiler val sourceFile = new JavaSourceFromString(className, - "public class " + className + " { @Override public String toString() { " + - "return \"" + value + "\";}}") + "public class " + className + " implements java.io.Serializable {" + + " @Override public String toString() { return \"" + value + "\"; }}") // Calling this outputs a class file in pwd. It's easier to just rename the file than // build a custom FileManager that controls the output location. diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index 76956f6a345d1..15fd30e65761d 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -106,23 +106,23 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable { * Actually get the broadcasted value. Concrete implementations of Broadcast class must * define their own way to get the value. */ - private[spark] def getValue(): T + protected def getValue(): T /** * Actually unpersist the broadcasted value on the executors. Concrete implementations of * Broadcast class must define their own logic to unpersist their own data. */ - private[spark] def doUnpersist(blocking: Boolean) + protected def doUnpersist(blocking: Boolean) /** * Actually destroy all data and metadata related to this broadcast variable. * Implementation of Broadcast class must define their own logic to destroy their own * state. */ - private[spark] def doDestroy(blocking: Boolean) + protected def doDestroy(blocking: Boolean) /** Check if this broadcast is valid. If not valid, exception is thrown. */ - private[spark] def assertValid() { + protected def assertValid() { if (!_isValid) { throw new SparkException("Attempted to use %s after it has been destroyed!".format(toString)) } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index c88be6aba6901..8f8a0b11f9f2e 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -39,7 +39,7 @@ private[spark] class BroadcastManager( synchronized { if (!initialized) { val broadcastFactoryClass = - conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") + conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") broadcastFactory = Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 4f6cabaff2b99..487456467b23b 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -40,9 +40,9 @@ private[spark] class HttpBroadcast[T: ClassTag]( @transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { - def getValue = value_ + override protected def getValue() = value_ - val blockId = BroadcastBlockId(id) + private val blockId = BroadcastBlockId(id) /* * Broadcasted data is also stored in the BlockManager of the driver. The BlockManagerMaster @@ -60,14 +60,14 @@ private[spark] class HttpBroadcast[T: ClassTag]( /** * Remove all persisted state associated with this HTTP broadcast on the executors. */ - def doUnpersist(blocking: Boolean) { + override protected def doUnpersist(blocking: Boolean) { HttpBroadcast.unpersist(id, removeFromDriver = false, blocking) } /** * Remove all persisted state associated with this HTTP broadcast on the executors and driver. */ - def doDestroy(blocking: Boolean) { + override protected def doDestroy(blocking: Boolean) { HttpBroadcast.unpersist(id, removeFromDriver = true, blocking) } @@ -102,7 +102,7 @@ private[spark] class HttpBroadcast[T: ClassTag]( } } -private[spark] object HttpBroadcast extends Logging { +private[broadcast] object HttpBroadcast extends Logging { private var initialized = false private var broadcastDir: File = null private var compress: Boolean = false @@ -160,7 +160,7 @@ private[spark] object HttpBroadcast extends Logging { def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name) - def write(id: Long, value: Any) { + private def write(id: Long, value: Any) { val file = getFile(id) val out: OutputStream = { if (compress) { @@ -176,7 +176,7 @@ private[spark] object HttpBroadcast extends Logging { files += file } - def read[T: ClassTag](id: Long): T = { + private def read[T: ClassTag](id: Long): T = { logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id) val url = serverUri + "/" + BroadcastBlockId(id).name diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala index d5a031e2bbb59..c7ef02d572a19 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala @@ -27,21 +27,21 @@ import org.apache.spark.{SecurityManager, SparkConf} * [[org.apache.spark.broadcast.HttpBroadcast]] for more details about this mechanism. */ class HttpBroadcastFactory extends BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { + override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { HttpBroadcast.initialize(isDriver, conf, securityMgr) } - def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = + override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = new HttpBroadcast[T](value_, isLocal, id) - def stop() { HttpBroadcast.stop() } + override def stop() { HttpBroadcast.stop() } /** * Remove all persisted state associated with the HTTP broadcast with the given ID. * @param removeFromDriver Whether to remove state from the driver * @param blocking Whether to block until unbroadcasted */ - def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) { + override def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) { HttpBroadcast.unpersist(id, removeFromDriver, blocking) } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 734de37ba115d..86731b684f441 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -20,7 +20,6 @@ package org.apache.spark.broadcast import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream} import scala.reflect.ClassTag -import scala.math import scala.util.Random import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} @@ -49,19 +48,19 @@ private[spark] class TorrentBroadcast[T: ClassTag]( @transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { - def getValue = value_ + override protected def getValue() = value_ - val broadcastId = BroadcastBlockId(id) + private val broadcastId = BroadcastBlockId(id) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) } - @transient var arrayOfBlocks: Array[TorrentBlock] = null - @transient var totalBlocks = -1 - @transient var totalBytes = -1 - @transient var hasBlocks = 0 + @transient private var arrayOfBlocks: Array[TorrentBlock] = null + @transient private var totalBlocks = -1 + @transient private var totalBytes = -1 + @transient private var hasBlocks = 0 if (!isLocal) { sendBroadcast() @@ -70,7 +69,7 @@ private[spark] class TorrentBroadcast[T: ClassTag]( /** * Remove all persisted state associated with this Torrent broadcast on the executors. */ - def doUnpersist(blocking: Boolean) { + override protected def doUnpersist(blocking: Boolean) { TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking) } @@ -78,11 +77,11 @@ private[spark] class TorrentBroadcast[T: ClassTag]( * Remove all persisted state associated with this Torrent broadcast on the executors * and driver. */ - def doDestroy(blocking: Boolean) { + override protected def doDestroy(blocking: Boolean) { TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking) } - def sendBroadcast() { + private def sendBroadcast() { val tInfo = TorrentBroadcast.blockifyObject(value_) totalBlocks = tInfo.totalBlocks totalBytes = tInfo.totalBytes @@ -159,7 +158,7 @@ private[spark] class TorrentBroadcast[T: ClassTag]( hasBlocks = 0 } - def receiveBroadcast(): Boolean = { + private def receiveBroadcast(): Boolean = { // Receive meta-info about the size of broadcast data, // the number of chunks it is divided into, etc. val metaId = BroadcastBlockId(id, "meta") @@ -211,7 +210,7 @@ private[spark] class TorrentBroadcast[T: ClassTag]( } -private[spark] object TorrentBroadcast extends Logging { +private[broadcast] object TorrentBroadcast extends Logging { private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 private var initialized = false private var conf: SparkConf = null @@ -272,17 +271,19 @@ private[spark] object TorrentBroadcast extends Logging { * Remove all persisted blocks associated with this torrent broadcast on the executors. * If removeFromDriver is true, also remove these persisted blocks on the driver. */ - def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized { - SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) + def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = { + synchronized { + SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) + } } } -private[spark] case class TorrentBlock( +private[broadcast] case class TorrentBlock( blockID: Int, byteArray: Array[Byte]) extends Serializable -private[spark] case class TorrentInfo( +private[broadcast] case class TorrentInfo( @transient arrayOfBlocks: Array[TorrentBlock], totalBlocks: Int, totalBytes: Int) diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala index 1de8396a0e17f..ad0f701d7a98f 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -28,21 +28,21 @@ import org.apache.spark.{SecurityManager, SparkConf} */ class TorrentBroadcastFactory extends BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { + override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { TorrentBroadcast.initialize(isDriver, conf) } - def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = + override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = new TorrentBroadcast[T](value_, isLocal, id) - def stop() { TorrentBroadcast.stop() } + override def stop() { TorrentBroadcast.stop() } /** * Remove all persisted state associated with the torrent broadcast with the given ID. * @param removeFromDriver Whether to remove state from the driver. * @param blocking Whether to block until unbroadcasted */ - def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) { + override def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) { TorrentBroadcast.unpersist(id, removeFromDriver, blocking) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index b050dccb6d57f..3d8373d8175ee 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -27,25 +27,39 @@ import org.apache.spark.executor.ExecutorURLClassLoader import org.apache.spark.util.Utils /** - * Scala code behind the spark-submit script. The script handles setting up the classpath with - * relevant Spark dependencies and provides a layer over the different cluster managers and deploy - * modes that Spark supports. + * Main gateway of launching a Spark application. + * + * This program handles setting up the classpath with relevant Spark dependencies and provides + * a layer over the different cluster managers and deploy modes that Spark supports. */ object SparkSubmit { + + // Cluster managers private val YARN = 1 private val STANDALONE = 2 private val MESOS = 4 private val LOCAL = 8 private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL - private var clusterManager: Int = LOCAL + // Deploy modes + private val CLIENT = 1 + private val CLUSTER = 2 + private val ALL_DEPLOY_MODES = CLIENT | CLUSTER - /** - * Special primary resource names that represent shells rather than application jars. - */ + // Special primary resource names that represent shells rather than application jars. private val SPARK_SHELL = "spark-shell" private val PYSPARK_SHELL = "pyspark-shell" + // Exposed for testing + private[spark] var exitFn: () => Unit = () => System.exit(-1) + private[spark] var printStream: PrintStream = System.err + private[spark] def printWarning(str: String) = printStream.println("Warning: " + str) + private[spark] def printErrorAndExit(str: String) = { + printStream.println("Error: " + str) + printStream.println("Run with --help for usage help or --verbose for debug output") + exitFn() + } + def main(args: Array[String]) { val appArgs = new SparkSubmitArguments(args) if (appArgs.verbose) { @@ -55,88 +69,80 @@ object SparkSubmit { launch(childArgs, classpath, sysProps, mainClass, appArgs.verbose) } - // Exposed for testing - private[spark] var printStream: PrintStream = System.err - private[spark] var exitFn: () => Unit = () => System.exit(-1) - - private[spark] def printErrorAndExit(str: String) = { - printStream.println("Error: " + str) - printStream.println("Run with --help for usage help or --verbose for debug output") - exitFn() - } - private[spark] def printWarning(str: String) = printStream.println("Warning: " + str) - /** - * @return a tuple containing the arguments for the child, a list of classpath - * entries for the child, a list of system properties, a list of env vars - * and the main class for the child + * @return a tuple containing + * (1) the arguments for the child process, + * (2) a list of classpath entries for the child, + * (3) a list of system properties and env vars, and + * (4) the main class for the child */ private[spark] def createLaunchEnv(args: SparkSubmitArguments) : (ArrayBuffer[String], ArrayBuffer[String], Map[String, String], String) = { - if (args.master.startsWith("local")) { - clusterManager = LOCAL - } else if (args.master.startsWith("yarn")) { - clusterManager = YARN - } else if (args.master.startsWith("spark")) { - clusterManager = STANDALONE - } else if (args.master.startsWith("mesos")) { - clusterManager = MESOS - } else { - printErrorAndExit("Master must start with yarn, mesos, spark, or local") - } - - // Because "yarn-cluster" and "yarn-client" encapsulate both the master - // and deploy mode, we have some logic to infer the master and deploy mode - // from each other if only one is specified, or exit early if they are at odds. - if (args.deployMode == null && - (args.master == "yarn-standalone" || args.master == "yarn-cluster")) { - args.deployMode = "cluster" - } - if (args.deployMode == "cluster" && args.master == "yarn-client") { - printErrorAndExit("Deploy mode \"cluster\" and master \"yarn-client\" are not compatible") - } - if (args.deployMode == "client" && - (args.master == "yarn-standalone" || args.master == "yarn-cluster")) { - printErrorAndExit("Deploy mode \"client\" and master \"" + args.master - + "\" are not compatible") - } - if (args.deployMode == "cluster" && args.master.startsWith("yarn")) { - args.master = "yarn-cluster" - } - if (args.deployMode != "cluster" && args.master.startsWith("yarn")) { - args.master = "yarn-client" - } - - val deployOnCluster = Option(args.deployMode).getOrElse("client") == "cluster" - val childClasspath = new ArrayBuffer[String]() + // Values to return val childArgs = new ArrayBuffer[String]() + val childClasspath = new ArrayBuffer[String]() val sysProps = new HashMap[String, String]() var childMainClass = "" - val isPython = args.isPython - val isYarnCluster = clusterManager == YARN && deployOnCluster + // Set the cluster manager + val clusterManager: Int = args.master match { + case m if m.startsWith("yarn") => YARN + case m if m.startsWith("spark") => STANDALONE + case m if m.startsWith("mesos") => MESOS + case m if m.startsWith("local") => LOCAL + case _ => printErrorAndExit("Master must start with yarn, spark, mesos, or local"); -1 + } - // For mesos, only client mode is supported - if (clusterManager == MESOS && deployOnCluster) { - printErrorAndExit("Cluster deploy mode is currently not supported for Mesos clusters.") + // Set the deploy mode; default is client mode + var deployMode: Int = args.deployMode match { + case "client" | null => CLIENT + case "cluster" => CLUSTER + case _ => printErrorAndExit("Deploy mode must be either client or cluster"); -1 } - // For standalone, only client mode is supported - if (clusterManager == STANDALONE && deployOnCluster) { - printErrorAndExit("Cluster deploy mode is currently not supported for standalone clusters.") + // Because "yarn-cluster" and "yarn-client" encapsulate both the master + // and deploy mode, we have some logic to infer the master and deploy mode + // from each other if only one is specified, or exit early if they are at odds. + if (clusterManager == YARN) { + if (args.master == "yarn-standalone") { + printWarning("\"yarn-standalone\" is deprecated. Use \"yarn-cluster\" instead.") + args.master = "yarn-cluster" + } + (args.master, args.deployMode) match { + case ("yarn-cluster", null) => + deployMode = CLUSTER + case ("yarn-cluster", "client") => + printErrorAndExit("Client deploy mode is not compatible with master \"yarn-cluster\"") + case ("yarn-client", "cluster") => + printErrorAndExit("Cluster deploy mode is not compatible with master \"yarn-client\"") + case (_, mode) => + args.master = "yarn-" + Option(mode).getOrElse("client") + } + + // Make sure YARN is included in our build if we're trying to use it + if (!Utils.classIsLoadable("org.apache.spark.deploy.yarn.Client") && !Utils.isTesting) { + printErrorAndExit( + "Could not load YARN classes. " + + "This copy of Spark may not have been compiled with YARN support.") + } } - // For shells, only client mode is applicable - if (isShell(args.primaryResource) && deployOnCluster) { - printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.") + // The following modes are not supported or applicable + (clusterManager, deployMode) match { + case (MESOS, CLUSTER) => + printErrorAndExit("Cluster deploy mode is currently not supported for Mesos clusters.") + case (STANDALONE, CLUSTER) => + printErrorAndExit("Cluster deploy mode is currently not supported for Standalone clusters.") + case (_, CLUSTER) if args.isPython => + printErrorAndExit("Cluster deploy mode is currently not supported for python applications.") + case (_, CLUSTER) if isShell(args.primaryResource) => + printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.") + case _ => } // If we're running a python app, set the main class to our specific python runner - if (isPython) { - if (deployOnCluster) { - printErrorAndExit("Cluster deploy mode is currently not supported for python.") - } + if (args.isPython) { if (args.primaryResource == PYSPARK_SHELL) { args.mainClass = "py4j.GatewayServer" args.childArgs = ArrayBuffer("--die-on-broken-pipe", "0") @@ -152,120 +158,115 @@ object SparkSubmit { sysProps("spark.submit.pyFiles") = PythonRunner.formatPaths(args.pyFiles).mkString(",") } - // If we're deploying into YARN, use yarn.Client as a wrapper around the user class - if (!deployOnCluster) { - childMainClass = args.mainClass - if (isUserJar(args.primaryResource)) { - childClasspath += args.primaryResource - } - } else if (clusterManager == YARN) { - childMainClass = "org.apache.spark.deploy.yarn.Client" - childArgs += ("--jar", args.primaryResource) - childArgs += ("--class", args.mainClass) - } - - // Make sure YARN is included in our build if we're trying to use it - if (clusterManager == YARN) { - if (!Utils.classIsLoadable("org.apache.spark.deploy.yarn.Client") && !Utils.isTesting) { - printErrorAndExit("Could not load YARN classes. " + - "This copy of Spark may not have been compiled with YARN support.") - } - } - // Special flag to avoid deprecation warnings at the client sysProps("SPARK_SUBMIT") = "true" // A list of rules to map each argument to system properties or command-line options in // each deploy mode; we iterate through these below val options = List[OptionAssigner]( - OptionAssigner(args.master, ALL_CLUSTER_MGRS, false, sysProp = "spark.master"), - OptionAssigner(args.name, ALL_CLUSTER_MGRS, false, sysProp = "spark.app.name"), - OptionAssigner(args.name, YARN, true, clOption = "--name", sysProp = "spark.app.name"), - OptionAssigner(args.driverExtraClassPath, STANDALONE | YARN, true, + + // All cluster managers + OptionAssigner(args.master, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.master"), + OptionAssigner(args.name, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.app.name"), + OptionAssigner(args.jars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars"), + + // Standalone cluster only + OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, clOption = "--memory"), + OptionAssigner(args.driverCores, STANDALONE, CLUSTER, clOption = "--cores"), + + // Yarn client only + OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), + OptionAssigner(args.numExecutors, YARN, CLIENT, sysProp = "spark.executor.instances"), + OptionAssigner(args.executorCores, YARN, CLIENT, sysProp = "spark.executor.cores"), + OptionAssigner(args.files, YARN, CLIENT, sysProp = "spark.yarn.dist.files"), + OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"), + + // Yarn cluster only + OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name", sysProp = "spark.app.name"), + OptionAssigner(args.driverMemory, YARN, CLUSTER, clOption = "--driver-memory"), + OptionAssigner(args.queue, YARN, CLUSTER, clOption = "--queue"), + OptionAssigner(args.numExecutors, YARN, CLUSTER, clOption = "--num-executors"), + OptionAssigner(args.executorMemory, YARN, CLUSTER, clOption = "--executor-memory"), + OptionAssigner(args.executorCores, YARN, CLUSTER, clOption = "--executor-cores"), + OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"), + OptionAssigner(args.archives, YARN, CLUSTER, clOption = "--archives"), + OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"), + + // Other options + OptionAssigner(args.driverExtraClassPath, STANDALONE | YARN, CLUSTER, sysProp = "spark.driver.extraClassPath"), - OptionAssigner(args.driverExtraJavaOptions, STANDALONE | YARN, true, + OptionAssigner(args.driverExtraJavaOptions, STANDALONE | YARN, CLUSTER, sysProp = "spark.driver.extraJavaOptions"), - OptionAssigner(args.driverExtraLibraryPath, STANDALONE | YARN, true, + OptionAssigner(args.driverExtraLibraryPath, STANDALONE | YARN, CLUSTER, sysProp = "spark.driver.extraLibraryPath"), - OptionAssigner(args.driverMemory, YARN, true, clOption = "--driver-memory"), - OptionAssigner(args.driverMemory, STANDALONE, true, clOption = "--memory"), - OptionAssigner(args.driverCores, STANDALONE, true, clOption = "--cores"), - OptionAssigner(args.queue, YARN, true, clOption = "--queue"), - OptionAssigner(args.queue, YARN, false, sysProp = "spark.yarn.queue"), - OptionAssigner(args.numExecutors, YARN, true, clOption = "--num-executors"), - OptionAssigner(args.numExecutors, YARN, false, sysProp = "spark.executor.instances"), - OptionAssigner(args.executorMemory, YARN, true, clOption = "--executor-memory"), - OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, false, + OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, CLIENT, sysProp = "spark.executor.memory"), - OptionAssigner(args.executorCores, YARN, true, clOption = "--executor-cores"), - OptionAssigner(args.executorCores, YARN, false, sysProp = "spark.executor.cores"), - OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, false, + OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, CLIENT, sysProp = "spark.cores.max"), - OptionAssigner(args.files, YARN, false, sysProp = "spark.yarn.dist.files"), - OptionAssigner(args.files, YARN, true, clOption = "--files"), - OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, false, sysProp = "spark.files"), - OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, true, sysProp = "spark.files"), - OptionAssigner(args.archives, YARN, false, sysProp = "spark.yarn.dist.archives"), - OptionAssigner(args.archives, YARN, true, clOption = "--archives"), - OptionAssigner(args.jars, YARN, true, clOption = "--addJars"), - OptionAssigner(args.jars, ALL_CLUSTER_MGRS, false, sysProp = "spark.jars") + OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, + sysProp = "spark.files") ) - // For client mode make any added jars immediately visible on the classpath - if (args.jars != null && !deployOnCluster) { - for (jar <- args.jars.split(",")) { - childClasspath += jar + // In client mode, launch the application main class directly + // In addition, add the main application jar and any added jars (if any) to the classpath + if (deployMode == CLIENT) { + childMainClass = args.mainClass + if (isUserJar(args.primaryResource)) { + childClasspath += args.primaryResource } + if (args.jars != null) { childClasspath ++= args.jars.split(",") } + if (args.childArgs != null) { childArgs ++= args.childArgs } } + // Map all arguments to command-line options or system properties for our chosen mode for (opt <- options) { - if (opt.value != null && deployOnCluster == opt.deployOnCluster && + if (opt.value != null && + (deployMode & opt.deployMode) != 0 && (clusterManager & opt.clusterManager) != 0) { - if (opt.clOption != null) { - childArgs += (opt.clOption, opt.value) - } - if (opt.sysProp != null) { - sysProps.put(opt.sysProp, opt.value) - } + if (opt.clOption != null) { childArgs += (opt.clOption, opt.value) } + if (opt.sysProp != null) { sysProps.put(opt.sysProp, opt.value) } } } // Add the application jar automatically so the user doesn't have to call sc.addJar // For YARN cluster mode, the jar is already distributed on each node as "app.jar" // For python files, the primary resource is already distributed as a regular file - if (!isYarnCluster && !isPython) { - var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq()) + val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER + if (!isYarnCluster && !args.isPython) { + var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty) if (isUserJar(args.primaryResource)) { jars = jars ++ Seq(args.primaryResource) } sysProps.put("spark.jars", jars.mkString(",")) } - // Standalone cluster specific configurations - if (deployOnCluster && clusterManager == STANDALONE) { + // In standalone-cluster mode, use Client as a wrapper around the user class + if (clusterManager == STANDALONE && deployMode == CLUSTER) { + childMainClass = "org.apache.spark.deploy.Client" if (args.supervise) { childArgs += "--supervise" } - childMainClass = "org.apache.spark.deploy.Client" childArgs += "launch" childArgs += (args.master, args.primaryResource, args.mainClass) + if (args.childArgs != null) { + childArgs ++= args.childArgs + } } - // Arguments to be passed to user program - if (args.childArgs != null) { - if (!deployOnCluster || clusterManager == STANDALONE) { - childArgs ++= args.childArgs - } else if (clusterManager == YARN) { - for (arg <- args.childArgs) { - childArgs += ("--arg", arg) - } + // In yarn-cluster mode, use yarn.Client as a wrapper around the user class + if (clusterManager == YARN && deployMode == CLUSTER) { + childMainClass = "org.apache.spark.deploy.yarn.Client" + childArgs += ("--jar", args.primaryResource) + childArgs += ("--class", args.mainClass) + if (args.childArgs != null) { + args.childArgs.foreach { arg => childArgs += ("--arg", arg) } } } // Read from default spark properties, if any for ((k, v) <- args.getDefaultSparkProperties) { - if (!sysProps.contains(k)) sysProps(k) = v + sysProps.getOrElseUpdate(k, v) } (childArgs, childClasspath, sysProps, childMainClass) @@ -364,6 +365,6 @@ object SparkSubmit { private[spark] case class OptionAssigner( value: String, clusterManager: Int, - deployOnCluster: Boolean, + deployMode: Int, clOption: String = null, sysProp: String = null) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index a304102a49086..bb1fcc8190fe4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy.master +import java.net.URLEncoder import java.text.SimpleDateFormat import java.util.Date @@ -30,7 +31,6 @@ import akka.actor._ import akka.pattern.ask import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import akka.serialization.SerializationExtension -import org.apache.hadoop.fs.FileSystem import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState} @@ -57,6 +57,7 @@ private[spark] class Master( def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs val WORKER_TIMEOUT = conf.getLong("spark.worker.timeout", 60) * 1000 val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) + val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200) val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15) val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "") val RECOVERY_MODE = conf.get("spark.deploy.recoveryMode", "NONE") @@ -72,9 +73,7 @@ private[spark] class Master( val waitingApps = new ArrayBuffer[ApplicationInfo] val completedApps = new ArrayBuffer[ApplicationInfo] var nextAppNumber = 0 - val appIdToUI = new HashMap[String, SparkUI] - val fileSystemsUsed = new HashSet[FileSystem] val drivers = new HashSet[DriverInfo] val completedDrivers = new ArrayBuffer[DriverInfo] @@ -159,7 +158,6 @@ private[spark] class Master( recoveryCompletionTask.cancel() } webUi.stop() - fileSystemsUsed.foreach(_.close()) masterMetricsSystem.stop() applicationMetricsSystem.stop() persistenceEngine.close() @@ -644,10 +642,7 @@ private[spark] class Master( waitingApps -= app // If application events are logged, use them to rebuild the UI - if (!rebuildSparkUI(app)) { - // Avoid broken links if the UI is not reconstructed - app.desc.appUiUrl = "" - } + rebuildSparkUI(app) for (exec <- app.executors.values) { exec.worker.removeExecutor(exec) @@ -669,29 +664,47 @@ private[spark] class Master( */ def rebuildSparkUI(app: ApplicationInfo): Boolean = { val appName = app.desc.name - val eventLogDir = app.desc.eventLogDir.getOrElse { return false } + val eventLogDir = app.desc.eventLogDir.getOrElse { + // Event logging is not enabled for this application + app.desc.appUiUrl = "/history/not-found" + return false + } val fileSystem = Utils.getHadoopFileSystem(eventLogDir) val eventLogInfo = EventLoggingListener.parseLoggingInfo(eventLogDir, fileSystem) val eventLogPaths = eventLogInfo.logPaths val compressionCodec = eventLogInfo.compressionCodec - if (!eventLogPaths.isEmpty) { - try { - val replayBus = new ReplayListenerBus(eventLogPaths, fileSystem, compressionCodec) - val ui = new SparkUI( - new SparkConf, replayBus, appName + " (completed)", "/history/" + app.id) - replayBus.replay() - app.desc.appUiUrl = ui.basePath - appIdToUI(app.id) = ui - webUi.attachSparkUI(ui) - return true - } catch { - case e: Exception => - logError("Exception in replaying log for application %s (%s)".format(appName, app.id), e) - } - } else { - logWarning("Application %s (%s) has no valid logs: %s".format(appName, app.id, eventLogDir)) + + if (eventLogPaths.isEmpty) { + // Event logging is enabled for this application, but no event logs are found + val title = s"Application history not found (${app.id})" + var msg = s"No event logs found for application $appName in $eventLogDir." + logWarning(msg) + msg += " Did you specify the correct logging directory?" + msg = URLEncoder.encode(msg, "UTF-8") + app.desc.appUiUrl = s"/history/not-found?msg=$msg&title=$title" + return false + } + + try { + val replayBus = new ReplayListenerBus(eventLogPaths, fileSystem, compressionCodec) + val ui = new SparkUI(new SparkConf, replayBus, appName + " (completed)", "/history/" + app.id) + replayBus.replay() + appIdToUI(app.id) = ui + webUi.attachSparkUI(ui) + // Application UI is successfully rebuilt, so link the Master UI to it + app.desc.appUiUrl = ui.basePath + true + } catch { + case e: Exception => + // Relay exception message to application UI page + val title = s"Application history load error (${app.id})" + val exception = URLEncoder.encode(Utils.exceptionString(e), "UTF-8") + var msg = s"Exception in replaying log for application $appName!" + logError(msg, e) + msg = URLEncoder.encode(msg, "UTF-8") + app.desc.appUiUrl = s"/history/not-found?msg=$msg&exception=$exception&title=$title" + false } - false } /** Generate a new app ID given a app's submission date */ @@ -744,11 +757,16 @@ private[spark] class Master( case Some(driver) => logInfo(s"Removing driver: $driverId") drivers -= driver + if (completedDrivers.size >= RETAINED_DRIVERS) { + val toRemove = math.max(RETAINED_DRIVERS / 10, 1) + completedDrivers.trimStart(toRemove) + } completedDrivers += driver persistenceEngine.removeDriver(driver) driver.state = finalState driver.exception = exception driver.worker.foreach(w => w.removeDriver(driver)) + schedule() case None => logWarning(s"Asked to remove unknown driver: $driverId") } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 34fa1429c86de..4588c130ef439 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -28,7 +28,7 @@ import org.json4s.JValue import org.apache.spark.deploy.{ExecutorState, JsonProtocol} import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master.ExecutorInfo -import org.apache.spark.ui.{WebUIPage, UIUtils} +import org.apache.spark.ui.{UIUtils, WebUIPage} import org.apache.spark.util.Utils private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/HistoryNotFoundPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/HistoryNotFoundPage.scala new file mode 100644 index 0000000000000..d8daff3e7fb9c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/HistoryNotFoundPage.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master.ui + +import java.net.URLDecoder +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.spark.ui.{UIUtils, WebUIPage} + +private[spark] class HistoryNotFoundPage(parent: MasterWebUI) + extends WebUIPage("history/not-found") { + + /** + * Render a page that conveys failure in loading application history. + * + * This accepts 3 HTTP parameters: + * msg = message to display to the user + * title = title of the page + * exception = detailed description of the exception in loading application history (if any) + * + * Parameters "msg" and "exception" are assumed to be UTF-8 encoded. + */ + def render(request: HttpServletRequest): Seq[Node] = { + val titleParam = request.getParameter("title") + val msgParam = request.getParameter("msg") + val exceptionParam = request.getParameter("exception") + + // If no parameters are specified, assume the user did not enable event logging + val defaultTitle = "Event logging is not enabled" + val defaultContent = +
+
+ No event logs were found for this application! To + enable event logging, + set spark.eventLog.enabled to true and + spark.eventLog.dir to the directory to which your + event logs are written. +
+
+ + val title = Option(titleParam).getOrElse(defaultTitle) + val content = Option(msgParam) + .map { msg => URLDecoder.decode(msg, "UTF-8") } + .map { msg => +
+
{msg}
+
++ + Option(exceptionParam) + .map { e => URLDecoder.decode(e, "UTF-8") } + .map { e =>
{e}
} + .getOrElse(Seq.empty) + }.getOrElse(defaultContent) + + UIUtils.basicSparkPage(content, title) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index a18b39fc95d64..16aa0493370dd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -21,7 +21,7 @@ import org.apache.spark.Logging import org.apache.spark.deploy.master.Master import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.AkkaUtils /** * Web UI server for the standalone master. @@ -38,6 +38,7 @@ class MasterWebUI(val master: Master, requestedPort: Int) /** Initialize all components of the server. */ def initialize() { attachPage(new ApplicationPage(this)) + attachPage(new HistoryNotFoundPage(this)) attachPage(new MasterPage(this)) attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static")) master.masterMetricsSystem.getServletHandlers.foreach(attachHandler) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 8d31bd05fdbec..b455c9fcf4bd6 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -71,7 +71,7 @@ private[spark] class CoarseGrainedExecutorBackend( val ser = SparkEnv.get.closureSerializer.newInstance() val taskDesc = ser.deserialize[TaskDescription](data.value) logInfo("Got assigned task " + taskDesc.taskId) - executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask) + executor.launchTask(this, taskDesc.taskId, taskDesc.name, taskDesc.serializedTask) } case KillTask(taskId, _, interruptThread) => diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 4d3ba11633bf5..b16133b20cc02 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -107,8 +107,9 @@ private[spark] class Executor( // Maintains the list of running tasks. private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] - def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) { - val tr = new TaskRunner(context, taskId, serializedTask) + def launchTask( + context: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) { + val tr = new TaskRunner(context, taskId, taskName, serializedTask) runningTasks.put(taskId, tr) threadPool.execute(tr) } @@ -135,14 +136,15 @@ private[spark] class Executor( localDirs } - class TaskRunner(execBackend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) + class TaskRunner( + execBackend: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) extends Runnable { @volatile private var killed = false @volatile private var task: Task[Any] = _ def kill(interruptThread: Boolean) { - logInfo("Executor is trying to kill task " + taskId) + logInfo(s"Executor is trying to kill $taskName (TID $taskId)") killed = true if (task != null) { task.kill(interruptThread) @@ -154,7 +156,7 @@ private[spark] class Executor( SparkEnv.set(env) Thread.currentThread.setContextClassLoader(replClassLoader) val ser = SparkEnv.get.closureSerializer.newInstance() - logInfo("Running task ID " + taskId) + logInfo(s"Running $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) var attemptedTask: Option[Task[Any]] = None var taskStart: Long = 0 @@ -207,25 +209,30 @@ private[spark] class Executor( val accumUpdates = Accumulators.values - val directResult = new DirectTaskResult(valueBytes, accumUpdates, - task.metrics.getOrElse(null)) + val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull) val serializedDirectResult = ser.serialize(directResult) - logInfo("Serialized size of result for " + taskId + " is " + serializedDirectResult.limit) - val serializedResult = { - if (serializedDirectResult.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { - logInfo("Storing result for " + taskId + " in local BlockManager") + val resultSize = serializedDirectResult.limit + + // directSend = sending directly back to the driver + val (serializedResult, directSend) = { + if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { val blockId = TaskResultBlockId(taskId) env.blockManager.putBytes( blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER) - ser.serialize(new IndirectTaskResult[Any](blockId)) + (ser.serialize(new IndirectTaskResult[Any](blockId)), false) } else { - logInfo("Sending result for " + taskId + " directly to driver") - serializedDirectResult + (serializedDirectResult, true) } } execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) - logInfo("Finished task ID " + taskId) + + if (directSend) { + logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver") + } else { + logInfo( + s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)") + } } catch { case ffe: FetchFailedException => { val reason = ffe.toTaskEndReason @@ -233,7 +240,7 @@ private[spark] class Executor( } case _: TaskKilledException | _: InterruptedException if task.killed => { - logInfo("Executor killed task " + taskId) + logInfo(s"Executor killed $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) } @@ -241,7 +248,7 @@ private[spark] class Executor( // Attempt to exit cleanly by informing the driver of our failure. // If anything goes wrong (or this was a fatal exception), we will delegate to // the default uncaught exception handler, which will terminate the Executor. - logError("Exception in task ID " + taskId, t) + logError(s"Exception in $taskName (TID $taskId)", t) val serviceTime = System.currentTimeMillis() - taskStart val metrics = attemptedTask.flatMap(t => t.metrics) @@ -249,7 +256,7 @@ private[spark] class Executor( m.executorRunTime = serviceTime m.jvmGCTime = gcTime - startGCTime } - val reason = ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics) + val reason = ExceptionFailure(t.getClass.getName, t.getMessage, t.getStackTrace, metrics) execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) // Don't forcibly exit unless the exception was inherently fatal, to avoid diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index 2232e6237bf26..a42c8b43bbf7f 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -64,7 +64,7 @@ private[spark] class MesosExecutorBackend if (executor == null) { logError("Received launchTask but executor was null") } else { - executor.launchTask(this, taskId, taskInfo.getData.asReadOnlyByteBuffer) + executor.launchTask(this, taskId, taskInfo.getName, taskInfo.getData.asReadOnlyByteBuffer) } } diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index ac73288442a74..5d59e00636ee6 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -75,7 +75,9 @@ class TaskMetrics extends Serializable { /** * If this task reads from shuffle output, metrics on getting shuffle data will be collected here */ - var shuffleReadMetrics: Option[ShuffleReadMetrics] = None + private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None + + def shuffleReadMetrics = _shuffleReadMetrics /** * If this task writes to shuffle output, metrics on the written shuffle data will be collected @@ -87,6 +89,22 @@ class TaskMetrics extends Serializable { * Storage statuses of any blocks that have been updated as a result of this task. */ var updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = None + + /** Adds the given ShuffleReadMetrics to any existing shuffle metrics for this task. */ + def updateShuffleReadMetrics(newMetrics: ShuffleReadMetrics) = synchronized { + _shuffleReadMetrics match { + case Some(existingMetrics) => + existingMetrics.shuffleFinishTime = math.max( + existingMetrics.shuffleFinishTime, newMetrics.shuffleFinishTime) + existingMetrics.fetchWaitTime += newMetrics.fetchWaitTime + existingMetrics.localBlocksFetched += newMetrics.localBlocksFetched + existingMetrics.remoteBlocksFetched += newMetrics.remoteBlocksFetched + existingMetrics.totalBlocksFetched += newMetrics.totalBlocksFetched + existingMetrics.remoteBytesRead += newMetrics.remoteBytesRead + case None => + _shuffleReadMetrics = Some(newMetrics) + } + } } private[spark] object TaskMetrics { diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 4b0fe1ab82999..1b66218d86dd9 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -20,6 +20,7 @@ package org.apache.spark.io import java.io.{InputStream, OutputStream} import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} +import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream} import org.xerial.snappy.{SnappyInputStream, SnappyOutputStream} import org.apache.spark.SparkConf @@ -55,7 +56,28 @@ private[spark] object CompressionCodec { ctor.newInstance(conf).asInstanceOf[CompressionCodec] } - val DEFAULT_COMPRESSION_CODEC = classOf[LZFCompressionCodec].getName + val DEFAULT_COMPRESSION_CODEC = classOf[SnappyCompressionCodec].getName +} + + +/** + * :: DeveloperApi :: + * LZ4 implementation of [[org.apache.spark.io.CompressionCodec]]. + * Block size can be configured by `spark.io.compression.lz4.block.size`. + * + * Note: The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. + */ +@DeveloperApi +class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { + + override def compressedOutputStream(s: OutputStream): OutputStream = { + val blockSize = conf.getInt("spark.io.compression.lz4.block.size", 32768) + new LZ4BlockOutputStream(s, blockSize) + } + + override def compressedInputStream(s: InputStream): InputStream = new LZ4BlockInputStream(s) } @@ -81,7 +103,7 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { /** * :: DeveloperApi :: * Snappy implementation of [[org.apache.spark.io.CompressionCodec]]. - * Block size can be configured by spark.io.compression.snappy.block.size. + * Block size can be configured by `spark.io.compression.snappy.block.size`. * * Note: The wire protocol for this codec is not guaranteed to be compatible across versions * of Spark. This is intended for use as an internal compression utility within a single Spark diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 5951865e56c9d..aca235a62a6a8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -170,17 +170,22 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: val createCombiner: (CoGroupValue => CoGroupCombiner) = value => { val newCombiner = Array.fill(numRdds)(new CoGroup) - value match { case (v, depNum) => newCombiner(depNum) += v } + newCombiner(value._2) += value._1 newCombiner } val mergeValue: (CoGroupCombiner, CoGroupValue) => CoGroupCombiner = (combiner, value) => { - value match { case (v, depNum) => combiner(depNum) += v } + combiner(value._2) += value._1 combiner } val mergeCombiners: (CoGroupCombiner, CoGroupCombiner) => CoGroupCombiner = (combiner1, combiner2) => { - combiner1.zip(combiner2).map { case (v1, v2) => v1 ++ v2 } + var depNum = 0 + while (depNum < numRdds) { + combiner1(depNum) ++= combiner2(depNum) + depNum += 1 + } + combiner1 } new ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner]( createCombiner, mergeValue, mergeCombiners) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index c45b759f007cc..e7221e3032c11 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -258,7 +258,7 @@ private[spark] class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanc val pgroup = PartitionGroup(nxt_replica) groupArr += pgroup addPartToPGroup(nxt_part, pgroup) - groupHash += (nxt_replica -> (ArrayBuffer(pgroup))) // list in case we have multiple + groupHash.put(nxt_replica, ArrayBuffer(pgroup)) // list in case we have multiple numCreated += 1 } } @@ -267,7 +267,7 @@ private[spark] class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanc var (nxt_replica, nxt_part) = rotIt.next() val pgroup = PartitionGroup(nxt_replica) groupArr += pgroup - groupHash.get(nxt_replica).get += pgroup + groupHash.getOrElseUpdate(nxt_replica, ArrayBuffer()) += pgroup var tries = 0 while (!addPartToPGroup(nxt_part, pgroup) && tries < targetLen) { // ensure at least one part nxt_part = rotIt.next()._2 diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 041028514399b..e521612ffc27c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -140,8 +140,8 @@ class HadoopRDD[K, V]( // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the // local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects. - // synchronize to prevent ConcurrentModificationException (Spark-1097, Hadoop-10456) - conf.synchronized { + // Synchronize to prevent ConcurrentModificationException (Spark-1097, Hadoop-10456). + HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { val newJobConf = new JobConf(conf) initLocalJobConfFuncOpt.map(f => f(newJobConf)) HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) @@ -246,6 +246,9 @@ class HadoopRDD[K, V]( } private[spark] object HadoopRDD { + /** Constructing Configuration objects is not threadsafe, use this lock to serialize. */ + val CONFIGURATION_INSTANTIATION_LOCK = new Object() + /** * The three methods below are helpers for accessing the local map, a property of the SparkEnv of * the local process. diff --git a/core/src/main/scala/org/apache/spark/rdd/MappedValuesRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MappedValuesRDD.scala index 2bc47eb9fcd74..a60952eee5901 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MappedValuesRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MappedValuesRDD.scala @@ -28,6 +28,6 @@ class MappedValuesRDD[K, V, U](prev: RDD[_ <: Product2[K, V]], f: V => U) override val partitioner = firstParent[Product2[K, U]].partitioner override def compute(split: Partition, context: TaskContext): Iterator[(K, U)] = { - firstParent[Product2[K, V]].iterator(split, context).map { case Product2(k ,v) => (k, f(v)) } + firstParent[Product2[K, V]].iterator(split, context).map { pair => (pair._1, f(pair._2)) } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index fc9beb166befe..a6b920467283e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -125,7 +125,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) zeroBuffer.get(zeroArray) lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance() - def createZero() = cachedSerializer.deserialize[U](ByteBuffer.wrap(zeroArray)) + val createZero = () => cachedSerializer.deserialize[U](ByteBuffer.wrap(zeroArray)) combineByKey[U]((v: V) => seqOp(createZero(), v), seqOp, combOp, partitioner) } @@ -171,7 +171,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) // When deserializing, use a lazy val to create just one instance of the serializer per task lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance() - def createZero() = cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray)) + val createZero = () => cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray)) combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner) } @@ -214,22 +214,22 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) throw new SparkException("reduceByKeyLocally() does not support array keys") } - def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = { + val reducePartition = (iter: Iterator[(K, V)]) => { val map = new JHashMap[K, V] - iter.foreach { case (k, v) => - val old = map.get(k) - map.put(k, if (old == null) v else func(old, v)) + iter.foreach { pair => + val old = map.get(pair._1) + map.put(pair._1, if (old == null) pair._2 else func(old, pair._2)) } Iterator(map) - } + } : Iterator[JHashMap[K, V]] - def mergeMaps(m1: JHashMap[K, V], m2: JHashMap[K, V]): JHashMap[K, V] = { - m2.foreach { case (k, v) => - val old = m1.get(k) - m1.put(k, if (old == null) v else func(old, v)) + val mergeMaps = (m1: JHashMap[K, V], m2: JHashMap[K, V]) => { + m2.foreach { pair => + val old = m1.get(pair._1) + m1.put(pair._1, if (old == null) pair._2 else func(old, pair._2)) } m1 - } + } : JHashMap[K, V] self.mapPartitions(reducePartition).reduce(mergeMaps) } @@ -353,19 +353,19 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Allows controlling the * partitioning of the resulting key-value pair RDD by passing a Partitioner. * - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over - * each key, using [[PairRDDFunctions.reduceByKey]] or [[PairRDDFunctions.combineByKey]] - * will provide much better performance. + * Note: This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] + * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ def groupByKey(partitioner: Partitioner): RDD[(K, Iterable[V])] = { // groupByKey shouldn't use map side combine because map side combine does not // reduce the amount of data shuffled and requires all map side data be inserted // into a hash table, leading to more objects in the old gen. - def createCombiner(v: V) = ArrayBuffer(v) - def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v - def mergeCombiners(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = c1 ++ c2 + val createCombiner = (v: V) => ArrayBuffer(v) + val mergeValue = (buf: ArrayBuffer[V], v: V) => buf += v + val mergeCombiners = (c1: ArrayBuffer[V], c2: ArrayBuffer[V]) => c1 ++ c2 val bufs = combineByKey[ArrayBuffer[V]]( - createCombiner _, mergeValue _, mergeCombiners _, partitioner, mapSideCombine=false) + createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine=false) bufs.mapValues(_.toIterable) } @@ -373,9 +373,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with into `numPartitions` partitions. * - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over - * each key, using [[PairRDDFunctions.reduceByKey]] or [[PairRDDFunctions.combineByKey]] - * will provide much better performance. + * Note: This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] + * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ def groupByKey(numPartitions: Int): RDD[(K, Iterable[V])] = { groupByKey(new HashPartitioner(numPartitions)) @@ -401,9 +401,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD. */ def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = { - this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => - for (v <- vs; w <- ws) yield (v, w) - } + this.cogroup(other, partitioner).flatMapValues( pair => + for (v <- pair._1; w <- pair._2) yield (v, w) + ) } /** @@ -413,11 +413,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * partition the output RDD. */ def leftOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, Option[W]))] = { - this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => - if (ws.isEmpty) { - vs.map(v => (v, None)) + this.cogroup(other, partitioner).flatMapValues { pair => + if (pair._2.isEmpty) { + pair._1.map(v => (v, None)) } else { - for (v <- vs; w <- ws) yield (v, Some(w)) + for (v <- pair._1; w <- pair._2) yield (v, Some(w)) } } } @@ -430,11 +430,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) */ def rightOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner) : RDD[(K, (Option[V], W))] = { - this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => - if (vs.isEmpty) { - ws.map(w => (None, w)) + this.cogroup(other, partitioner).flatMapValues { pair => + if (pair._1.isEmpty) { + pair._2.map(w => (None, w)) } else { - for (v <- vs; w <- ws) yield (Some(v), w) + for (v <- pair._1; w <- pair._2) yield (Some(v), w) } } } @@ -462,9 +462,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with the existing partitioner/parallelism level. * - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over - * each key, using [[PairRDDFunctions.reduceByKey]] or [[PairRDDFunctions.combineByKey]] - * will provide much better performance, + * Note: This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] + * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ def groupByKey(): RDD[(K, Iterable[V])] = { groupByKey(defaultPartitioner(self)) @@ -535,7 +535,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val data = self.collect() val map = new mutable.HashMap[K, V] map.sizeHint(data.length) - data.foreach { case (k, v) => map.put(k, v) } + data.foreach { pair => map.put(pair._1, pair._2) } map } @@ -572,10 +572,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } val cg = new CoGroupedRDD[K](Seq(self, other1, other2, other3), partitioner) cg.mapValues { case Seq(vs, w1s, w2s, w3s) => - (vs.asInstanceOf[Seq[V]], - w1s.asInstanceOf[Seq[W1]], - w2s.asInstanceOf[Seq[W2]], - w3s.asInstanceOf[Seq[W3]]) + (vs.asInstanceOf[Seq[V]], + w1s.asInstanceOf[Seq[W1]], + w2s.asInstanceOf[Seq[W2]], + w3s.asInstanceOf[Seq[W3]]) } } @@ -589,8 +589,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) throw new SparkException("Default partitioner cannot partition array keys.") } val cg = new CoGroupedRDD[K](Seq(self, other), partitioner) - cg.mapValues { case Seq(vs, ws) => - (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]]) + cg.mapValues { case Seq(vs, w1s) => + (vs.asInstanceOf[Seq[V]], w1s.asInstanceOf[Seq[W]]) } } @@ -606,8 +606,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val cg = new CoGroupedRDD[K](Seq(self, other1, other2), partitioner) cg.mapValues { case Seq(vs, w1s, w2s) => (vs.asInstanceOf[Seq[V]], - w1s.asInstanceOf[Seq[W1]], - w2s.asInstanceOf[Seq[W2]]) + w1s.asInstanceOf[Seq[W1]], + w2s.asInstanceOf[Seq[W2]]) } } @@ -710,14 +710,14 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) self.partitioner match { case Some(p) => val index = p.getPartition(key) - def process(it: Iterator[(K, V)]): Seq[V] = { + val process = (it: Iterator[(K, V)]) => { val buf = new ArrayBuffer[V] - for ((k, v) <- it if k == key) { - buf += v + for (pair <- it if pair._1 == key) { + buf += pair._2 } buf - } - val res = self.context.runJob(self, process _, Array(index), false) + } : Seq[V] + val res = self.context.runJob(self, process, Array(index), false) res(0) case None => self.filter(_._1 == key).map(_._2).collect() @@ -840,7 +840,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) jobFormat.checkOutputSpecs(job) } - def writeShard(context: TaskContext, iter: Iterator[(K,V)]): Int = { + val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. val attemptNumber = (context.attemptId % Int.MaxValue).toInt @@ -858,22 +858,21 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] try { while (iter.hasNext) { - val (k, v) = iter.next() - writer.write(k, v) + val pair = iter.next() + writer.write(pair._1, pair._2) } - } - finally { + } finally { writer.close(hadoopContext) } committer.commitTask(hadoopContext) - return 1 - } + 1 + } : Int val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0) val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) jobCommitter.setupJob(jobTaskContext) - self.context.runJob(self, writeShard _) + self.context.runJob(self, writeShard) jobCommitter.commitJob(jobTaskContext) } @@ -912,7 +911,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val writer = new SparkHadoopWriter(hadoopConf) writer.preSetup() - def writeToFile(context: TaskContext, iter: Iterator[(K, V)]) { + val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. val attemptNumber = (context.attemptId % Int.MaxValue).toInt @@ -921,19 +920,18 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.open() try { var count = 0 - while(iter.hasNext) { + while (iter.hasNext) { val record = iter.next() count += 1 writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) } - } - finally { + } finally { writer.close() } writer.commit() } - self.context.runJob(self, writeToFile _) + self.context.runJob(self, writeToFile) writer.commitJob() } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 4e841bc992bff..88a918aebf763 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -328,7 +328,7 @@ abstract class RDD[T: ClassTag]( : RDD[T] = { if (shuffle) { /** Distributes elements evenly across output partitions, starting from a random partition. */ - def distributePartition(index: Int, items: Iterator[T]): Iterator[(Int, T)] = { + val distributePartition = (index: Int, items: Iterator[T]) => { var position = (new Random(index)).nextInt(numPartitions) items.map { t => // Note that the hash code of the key will just be the key itself. The HashPartitioner @@ -336,7 +336,7 @@ abstract class RDD[T: ClassTag]( position = position + 1 (position, t) } - } + } : Iterator[(Int, T)] // include a shuffle step so that our upstream tasks are still distributed new CoalescedRDD( @@ -509,6 +509,10 @@ abstract class RDD[T: ClassTag]( /** * Return an RDD of grouped items. Each group consists of a key and a sequence of elements * mapping to that key. + * + * Note: This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] + * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ def groupBy[K](f: T => K)(implicit kt: ClassTag[K]): RDD[(K, Iterable[T])] = groupBy[K](f, defaultPartitioner(this)) @@ -516,6 +520,10 @@ abstract class RDD[T: ClassTag]( /** * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements * mapping to that key. + * + * Note: This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] + * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ def groupBy[K](f: T => K, numPartitions: Int)(implicit kt: ClassTag[K]): RDD[(K, Iterable[T])] = groupBy(f, new HashPartitioner(numPartitions)) @@ -523,6 +531,10 @@ abstract class RDD[T: ClassTag]( /** * Return an RDD of grouped items. Each group consists of a key and a sequence of elements * mapping to that key. + * + * Note: This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] + * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ def groupBy[K](f: T => K, p: Partitioner)(implicit kt: ClassTag[K], ord: Ordering[K] = null) : RDD[(K, Iterable[T])] = { @@ -907,19 +919,19 @@ abstract class RDD[T: ClassTag]( throw new SparkException("countByValue() does not support arrays") } // TODO: This should perhaps be distributed by default. - def countPartition(iter: Iterator[T]): Iterator[OpenHashMap[T,Long]] = { + val countPartition = (iter: Iterator[T]) => { val map = new OpenHashMap[T,Long] iter.foreach { t => map.changeValue(t, 1L, _ + 1L) } Iterator(map) - } - def mergeMaps(m1: OpenHashMap[T,Long], m2: OpenHashMap[T,Long]): OpenHashMap[T,Long] = { + }: Iterator[OpenHashMap[T,Long]] + val mergeMaps = (m1: OpenHashMap[T,Long], m2: OpenHashMap[T,Long]) => { m2.foreach { case (key, value) => m1.changeValue(key, value, _ + value) } m1 - } + }: OpenHashMap[T,Long] val myResult = mapPartitions(countPartition).reduce(mergeMaps) // Convert to a Scala mutable map val mutableResult = scala.collection.mutable.Map[T,Long]() diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index f72bfde572c96..ede3c7d9f01ae 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -816,7 +816,6 @@ class DAGScheduler( } event.reason match { case Success => - logInfo("Completed " + task) if (event.accumUpdates != null) { // TODO: fail the stage if the accumulator update fails... Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index 6a6d8e609bc39..e41e0a9841691 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -30,4 +30,5 @@ private[spark] trait SchedulerBackend { def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = throw new UnsupportedOperationException + def isReady(): Boolean = true } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 29de0453ac19a..ca0595f35143e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -84,6 +84,8 @@ class TaskInfo( } } + def id: String = s"$index.$attempt" + def duration: Long = { if (!finished) { throw new UnsupportedOperationException("duration() called on unfinished task") diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 5ed2803d76afc..be3673c48eda8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -88,6 +88,8 @@ private[spark] class TaskSchedulerImpl( // in turn is used to decide when we can attain data locality on a given host private val executorsByHost = new HashMap[String, HashSet[String]] + protected val hostsByRack = new HashMap[String, HashSet[String]] + private val executorIdToHost = new HashMap[String, String] // Listener object to pass upcalls into @@ -145,6 +147,10 @@ private[spark] class TaskSchedulerImpl( } } + override def postStartHook() { + waitBackendReady() + } + override def submitTasks(taskSet: TaskSet) { val tasks = taskSet.tasks logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") @@ -219,6 +225,9 @@ private[spark] class TaskSchedulerImpl( executorAdded(o.executorId, o.host) newExecAvail = true } + for (rack <- getRackForHost(o.host)) { + hostsByRack.getOrElseUpdate(rack, new HashSet[String]()) += o.host + } } // Randomly shuffle offers to avoid always placing tasks on the same set of workers. @@ -414,6 +423,12 @@ private[spark] class TaskSchedulerImpl( execs -= executorId if (execs.isEmpty) { executorsByHost -= host + for (rack <- getRackForHost(host); hosts <- hostsByRack.get(rack)) { + hosts -= host + if (hosts.isEmpty) { + hostsByRack -= rack + } + } } executorIdToHost -= executorId rootPool.executorLost(executorId, host) @@ -431,12 +446,27 @@ private[spark] class TaskSchedulerImpl( executorsByHost.contains(host) } + def hasHostAliveOnRack(rack: String): Boolean = synchronized { + hostsByRack.contains(rack) + } + def isExecutorAlive(execId: String): Boolean = synchronized { activeExecutorIds.contains(execId) } // By default, rack is unknown def getRackForHost(value: String): Option[String] = None + + private def waitBackendReady(): Unit = { + if (backend.isReady) { + return + } + while (!backend.isReady) { + synchronized { + this.wait(100) + } + } + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 059cc9085a2e7..8b5e8cb802a45 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -26,8 +26,7 @@ import scala.collection.mutable.HashSet import scala.math.max import scala.math.min -import org.apache.spark.{ExceptionFailure, ExecutorLostFailure, FetchFailed, Logging, Resubmitted, - SparkEnv, Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState} +import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.{Clock, SystemClock} @@ -52,8 +51,8 @@ private[spark] class TaskSetManager( val taskSet: TaskSet, val maxTaskFailures: Int, clock: Clock = SystemClock) - extends Schedulable with Logging -{ + extends Schedulable with Logging { + val conf = sched.sc.conf /* @@ -191,7 +190,9 @@ private[spark] class TaskSetManager( addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer)) for (rack <- sched.getRackForHost(loc.host)) { addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer)) - hadAliveLocations = true + if(sched.hasHostAliveOnRack(rack)){ + hadAliveLocations = true + } } } @@ -401,14 +402,11 @@ private[spark] class TaskSetManager( // Found a task; do some bookkeeping and return a task description val task = tasks(index) val taskId = sched.newTaskId() - // Figure out whether this should count as a preferred launch - logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format( - taskSet.id, index, taskId, execId, host, taskLocality)) // Do various bookkeeping copiesRunning(index) += 1 val attemptNum = taskAttempts(index).size - val info = new TaskInfo( - taskId, index, attemptNum + 1, curTime, execId, host, taskLocality, speculative) + val info = new TaskInfo(taskId, index, attemptNum, curTime, + execId, host, taskLocality, speculative) taskInfos(taskId) = info taskAttempts(index) = info :: taskAttempts(index) // Update our locality level for delay scheduling @@ -427,11 +425,15 @@ private[spark] class TaskSetManager( s"(${serializedTask.limit / 1024} KB). The maximum recommended task size is " + s"${TaskSetManager.TASK_SIZE_TO_WARN_KB} KB.") } - val timeTaken = clock.getTime() - startTime addRunningTask(taskId) - logInfo("Serialized task %s:%d as %d bytes in %d ms".format( - taskSet.id, index, serializedTask.limit, timeTaken)) - val taskName = "task %s:%d".format(taskSet.id, index) + + // We used to log the time it takes to serialize the task, but task size is already + // a good proxy to task serialization time. + // val timeTaken = clock.getTime() - startTime + val taskName = s"task ${info.id} in stage ${taskSet.id}" + logInfo("Starting %s (TID %d, %s, %s, %d bytes)".format( + taskName, taskId, host, taskLocality, serializedTask.limit)) + sched.dagScheduler.taskStarted(task, info) return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask)) } @@ -490,19 +492,19 @@ private[spark] class TaskSetManager( info.markSuccessful() removeRunningTask(tid) sched.dagScheduler.taskEnded( - tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) + tasks(index), Success, result.value(), result.accumUpdates, info, result.metrics) if (!successful(index)) { tasksSuccessful += 1 - logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format( - tid, info.duration, info.host, tasksSuccessful, numTasks)) + logInfo("Finished task %s in stage %s (TID %d) in %d ms on %s (%d/%d)".format( + info.id, taskSet.id, info.taskId, info.duration, info.host, tasksSuccessful, numTasks)) // Mark successful and stop if all the tasks have succeeded. successful(index) = true if (tasksSuccessful == numTasks) { isZombie = true } } else { - logInfo("Ignorning task-finished event for TID " + tid + " because task " + - index + " has already completed successfully") + logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id + + " because task " + index + " has already completed successfully") } failedExecutors.remove(index) maybeFinishTaskSet() @@ -521,14 +523,13 @@ private[spark] class TaskSetManager( info.markFailed() val index = info.index copiesRunning(index) -= 1 - if (!isZombie) { - logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) - } var taskMetrics : TaskMetrics = null - var failureReason: String = null + + val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " + + reason.asInstanceOf[TaskFailedReason].toErrorString reason match { case fetchFailed: FetchFailed => - logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress) + logWarning(failureReason) if (!successful(index)) { successful(index) = true tasksSuccessful += 1 @@ -536,23 +537,17 @@ private[spark] class TaskSetManager( // Not adding to failed executors for FetchFailed. isZombie = true - case TaskKilled => - // Not adding to failed executors for TaskKilled. - logWarning("Task %d was killed.".format(tid)) - case ef: ExceptionFailure => - taskMetrics = ef.metrics.getOrElse(null) - if (ef.className == classOf[NotSerializableException].getName()) { + taskMetrics = ef.metrics.orNull + if (ef.className == classOf[NotSerializableException].getName) { // If the task result wasn't serializable, there's no point in trying to re-execute it. - logError("Task %s:%s had a not serializable result: %s; not retrying".format( - taskSet.id, index, ef.description)) - abort("Task %s:%s had a not serializable result: %s".format( - taskSet.id, index, ef.description)) + logError("Task %s in stage %s (TID %d) had a not serializable result: %s; not retrying" + .format(info.id, taskSet.id, tid, ef.description)) + abort("Task %s in stage %s (TID %d) had a not serializable result: %s".format( + info.id, taskSet.id, tid, ef.description)) return } val key = ef.description - failureReason = "Exception failure in TID %s on host %s: %s\n%s".format( - tid, info.host, ef.description, ef.stackTrace.map(" " + _).mkString("\n")) val now = clock.getTime() val (printFull, dupCount) = { if (recentExceptions.contains(key)) { @@ -570,19 +565,18 @@ private[spark] class TaskSetManager( } } if (printFull) { - val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) - logWarning("Loss was due to %s\n%s\n%s".format( - ef.className, ef.description, locs.mkString("\n"))) + logWarning(failureReason) } else { - logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) + logInfo( + s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid) on executor ${info.host}: " + + s"${ef.className} (${ef.description}) [duplicate $dupCount]") } - case TaskResultLost => - failureReason = "Lost result for TID %s on host %s".format(tid, info.host) + case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others logWarning(failureReason) - case _ => - failureReason = "TID %s on host %s failed for unknown reason".format(tid, info.host) + case e: TaskEndReason => + logError("Unknown TaskEndReason: " + e) } // always add to failed executors failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()). @@ -593,10 +587,10 @@ private[spark] class TaskSetManager( assert (null != failureReason) numFailures(index) += 1 if (numFailures(index) >= maxTaskFailures) { - logError("Task %s:%d failed %d times; aborting job".format( - taskSet.id, index, maxTaskFailures)) - abort("Task %s:%d failed %d times, most recent failure: %s\nDriver stacktrace:".format( - taskSet.id, index, maxTaskFailures, failureReason)) + logError("Task %d in stage %s failed %d times; aborting job".format( + index, taskSet.id, maxTaskFailures)) + abort("Task %d in stage %s failed %d times, most recent failure: %s\nDriver stacktrace:" + .format(index, taskSet.id, maxTaskFailures, failureReason)) return } } @@ -709,8 +703,8 @@ private[spark] class TaskSetManager( if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && !speculatableTasks.contains(index)) { logInfo( - "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format( - taskSet.id, index, info.host, threshold)) + "Marking task %d in stage %s (on %s) as speculatable because it ran more than %.0f ms" + .format(index, taskSet.id, info.host, threshold)) speculatableTasks += index foundTasks = true } @@ -748,7 +742,8 @@ private[spark] class TaskSetManager( pendingTasksForHost.keySet.exists(sched.hasExecutorsAliveOnHost(_))) { levels += NODE_LOCAL } - if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) { + if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0 && + pendingTasksForRack.keySet.exists(sched.hasHostAliveOnRack(_))) { levels += RACK_LOCAL } levels += ANY @@ -761,7 +756,8 @@ private[spark] class TaskSetManager( def newLocAvail(index: Int): Boolean = { for (loc <- tasks(index).preferredLocations) { if (sched.hasExecutorsAliveOnHost(loc.host) || - sched.getRackForHost(loc.host).isDefined) { + (sched.getRackForHost(loc.host).isDefined && + sched.hasHostAliveOnRack(sched.getRackForHost(loc.host).get))) { return true } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 318e16552201c..6abf6d930c155 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -66,4 +66,7 @@ private[spark] object CoarseGrainedClusterMessages { case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage + case class AddWebUIFilter(filterName:String, filterParams: String, proxyBase :String) + extends CoarseGrainedClusterMessage + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 05d01b0c821f9..9f085eef46720 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -31,6 +31,7 @@ import org.apache.spark.{SparkEnv, Logging, SparkException, TaskState} import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{SerializableBuffer, AkkaUtils, Utils} +import org.apache.spark.ui.JettyUtils /** * A scheduler backend that waits for coarse grained executors to connect to it through Akka. @@ -46,9 +47,19 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed var totalCoreCount = new AtomicInteger(0) + var totalExpectedExecutors = new AtomicInteger(0) val conf = scheduler.sc.conf private val timeout = AkkaUtils.askTimeout(conf) private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) + // Submit tasks only after (registered executors / total expected executors) + // is equal to at least this value, that is double between 0 and 1. + var minRegisteredRatio = conf.getDouble("spark.scheduler.minRegisteredExecutorsRatio", 0) + if (minRegisteredRatio > 1) minRegisteredRatio = 1 + // Whatever minRegisteredExecutorsRatio is arrived, submit tasks after the time(milliseconds). + val maxRegisteredWaitingTime = + conf.getInt("spark.scheduler.maxRegisteredExecutorsWaitingTime", 30000) + val createTime = System.currentTimeMillis() + var ready = if (minRegisteredRatio <= 0) true else false class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor { private val executorActor = new HashMap[String, ActorRef] @@ -83,6 +94,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A executorAddress(executorId) = sender.path.address addressToExecutorId(sender.path.address) = executorId totalCoreCount.addAndGet(cores) + if (executorActor.size >= totalExpectedExecutors.get() * minRegisteredRatio && !ready) { + ready = true + logInfo("SchedulerBackend is ready for scheduling beginning, registered executors: " + + executorActor.size + ", total expected executors: " + totalExpectedExecutors.get() + + ", minRegisteredExecutorsRatio: " + minRegisteredRatio) + } makeOffers() } @@ -120,6 +137,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A removeExecutor(executorId, reason) sender ! true + case AddWebUIFilter(filterName, filterParams, proxyBase) => + addWebUIFilter(filterName, filterParams, proxyBase) + sender ! true case DisassociatedEvent(_, address, _) => addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disassociated")) @@ -247,6 +267,33 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A throw new SparkException("Error notifying standalone scheduler's driver actor", e) } } + + override def isReady(): Boolean = { + if (ready) { + return true + } + if ((System.currentTimeMillis() - createTime) >= maxRegisteredWaitingTime) { + ready = true + logInfo("SchedulerBackend is ready for scheduling beginning after waiting " + + "maxRegisteredExecutorsWaitingTime: " + maxRegisteredWaitingTime) + return true + } + false + } + + // Add filters to the SparkUI + def addWebUIFilter(filterName: String, filterParams: String, proxyBase: String) { + if (proxyBase != null && proxyBase.nonEmpty) { + System.setProperty("spark.ui.proxyBase", proxyBase) + } + + if (Seq(filterName, filterParams).forall(t => t != null && t.nonEmpty)) { + logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase") + conf.set("spark.ui.filters", filterName) + conf.set(s"spark.$filterName.params", filterParams) + JettyUtils.addFilters(scheduler.sc.ui.getHandlers, conf) + } + } } private[spark] object CoarseGrainedSchedulerBackend { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 9c07b3f7b695a..bf2dc88e29048 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -95,6 +95,7 @@ private[spark] class SparkDeploySchedulerBackend( override def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int) { + totalExpectedExecutors.addAndGet(1) logInfo("Granted executor ID %s on hostPort %s with %d cores, %s RAM".format( fullId, hostPort, cores, Utils.megabytesToString(memory))) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 9b95ccca0443e..e9f6273bfd9f0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -69,7 +69,7 @@ private[spark] class LocalActor( val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) for (task <- scheduler.resourceOffers(offers).flatten) { freeCores -= 1 - executor.launchTask(executorBackend, task.taskId, task.serializedTask) + executor.launchTask(executorBackend, task.taskId, task.name, task.serializedTask) } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index a932455776e34..3795994cd920f 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -84,7 +84,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks - context.taskMetrics.shuffleReadMetrics = Some(shuffleMetrics) + context.taskMetrics.updateShuffleReadMetrics(shuffleMetrics) }) new InterruptibleIterator[T](context, completionIter) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 408a797088059..2f0296c20f2e2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -180,9 +180,9 @@ object BlockFetcherIterator { if (curRequestSize >= targetRequestSize) { // Add this FetchRequest remoteRequests += new FetchRequest(address, curBlocks) - curRequestSize = 0 curBlocks = new ArrayBuffer[(BlockId, Long)] logDebug(s"Creating fetch request of $curRequestSize at $address") + curRequestSize = 0 } } // Add in the final request diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 6aed322eeb185..de1cc5539fb48 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -336,6 +336,10 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case None => blockManagerIdByExecutor(id.executorId) = id } + + logInfo("Registering block manager %s with %s RAM".format( + id.hostPort, Utils.bytesToString(maxMemSize))) + blockManagerInfo(id) = new BlockManagerInfo(id, System.currentTimeMillis(), maxMemSize, slaveActor) } @@ -432,9 +436,6 @@ private[spark] class BlockManagerInfo( // Mapping from block id to its status. private val _blocks = new JHashMap[BlockId, BlockStatus] - logInfo("Registering block manager %s with %s RAM".format( - blockManagerId.hostPort, Utils.bytesToString(maxMem))) - def getStatus(blockId: BlockId) = Option(_blocks.get(blockId)) def updateLastSeenMs() { diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 9cb50d9b83dda..e07aa2ee3a5a2 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -136,7 +136,16 @@ private[spark] object UIUtils extends Logging { } // Yarn has to go through a proxy so the base uri is provided and has to be on all links - val uiRoot : String = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).getOrElse("") + def uiRoot: String = { + if (System.getenv("APPLICATION_WEB_PROXY_BASE") != null) { + System.getenv("APPLICATION_WEB_PROXY_BASE") + } else if (System.getProperty("spark.ui.proxyBase") != null) { + System.getProperty("spark.ui.proxyBase") + } + else { + "" + } + } def prependBaseUri(basePath: String = "", resource: String = "") = uiRoot + basePath + resource diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 52020954ea57c..0cc51c873727d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import scala.xml.Node import org.apache.spark.ui.{ToolTips, UIUtils} +import org.apache.spark.ui.jobs.UIData.StageUIData import org.apache.spark.util.Utils /** Page showing executor summary */ @@ -64,11 +65,9 @@ private[ui] class ExecutorTable(stageId: Int, parent: JobProgressTab) { executorIdToAddress.put(executorId, address) } - val executorIdToSummary = listener.stageIdToExecutorSummaries.get(stageId) - executorIdToSummary match { - case Some(x) => - x.toSeq.sortBy(_._1).map { case (k, v) => { - // scalastyle:off + listener.stageIdToData.get(stageId) match { + case Some(stageData: StageUIData) => + stageData.executorSummary.toSeq.sortBy(_._1).map { case (k, v) => {k} {executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")} @@ -76,16 +75,20 @@ private[ui] class ExecutorTable(stageId: Int, parent: JobProgressTab) { {v.failedTasks + v.succeededTasks} {v.failedTasks} {v.succeededTasks} - {Utils.bytesToString(v.inputBytes)} - {Utils.bytesToString(v.shuffleRead)} - {Utils.bytesToString(v.shuffleWrite)} - {Utils.bytesToString(v.memoryBytesSpilled)} - {Utils.bytesToString(v.diskBytesSpilled)} + + {Utils.bytesToString(v.inputBytes)} + + {Utils.bytesToString(v.shuffleRead)} + + {Utils.bytesToString(v.shuffleWrite)} + + {Utils.bytesToString(v.memoryBytesSpilled)} + + {Utils.bytesToString(v.diskBytesSpilled)} - // scalastyle:on } - } - case _ => Seq[Node]() + case None => + Seq.empty[Node] } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 2286a7f952f28..efb527b4f03e6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -25,6 +25,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.BlockManagerId +import org.apache.spark.ui.jobs.UIData._ /** * :: DeveloperApi :: @@ -35,7 +36,7 @@ import org.apache.spark.storage.BlockManagerId * updating the internal data structures concurrently. */ @DeveloperApi -class JobProgressListener(conf: SparkConf) extends SparkListener { +class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { import JobProgressListener._ @@ -46,20 +47,8 @@ class JobProgressListener(conf: SparkConf) extends SparkListener { val completedStages = ListBuffer[StageInfo]() val failedStages = ListBuffer[StageInfo]() - // TODO: Should probably consolidate all following into a single hash map. - val stageIdToTime = HashMap[Int, Long]() - val stageIdToInputBytes = HashMap[Int, Long]() - val stageIdToShuffleRead = HashMap[Int, Long]() - val stageIdToShuffleWrite = HashMap[Int, Long]() - val stageIdToMemoryBytesSpilled = HashMap[Int, Long]() - val stageIdToDiskBytesSpilled = HashMap[Int, Long]() - val stageIdToTasksActive = HashMap[Int, HashMap[Long, TaskInfo]]() - val stageIdToTasksComplete = HashMap[Int, Int]() - val stageIdToTasksFailed = HashMap[Int, Int]() - val stageIdToTaskData = HashMap[Int, HashMap[Long, TaskUIData]]() - val stageIdToExecutorSummaries = HashMap[Int, HashMap[String, ExecutorSummary]]() - val stageIdToPool = HashMap[Int, String]() - val stageIdToDescription = HashMap[Int, String]() + val stageIdToData = new HashMap[Int, StageUIData] + val poolToActiveStages = HashMap[String, HashMap[Int, StageInfo]]() val executorIdToBlockManagerId = HashMap[String, BlockManagerId]() @@ -71,8 +60,12 @@ class JobProgressListener(conf: SparkConf) extends SparkListener { override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { val stage = stageCompleted.stageInfo val stageId = stage.stageId - // Remove by stageId, rather than by StageInfo, in case the StageInfo is from storage - poolToActiveStages(stageIdToPool(stageId)).remove(stageId) + val stageData = stageIdToData.getOrElseUpdate(stageId, { + logWarning("Stage completed for unknown stage " + stageId) + new StageUIData + }) + + poolToActiveStages.get(stageData.schedulingPool).foreach(_.remove(stageId)) activeStages.remove(stageId) if (stage.failureReason.isEmpty) { completedStages += stage @@ -87,21 +80,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener { private def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized { if (stages.size > retainedStages) { val toRemove = math.max(retainedStages / 10, 1) - stages.take(toRemove).foreach { s => - stageIdToTime.remove(s.stageId) - stageIdToInputBytes.remove(s.stageId) - stageIdToShuffleRead.remove(s.stageId) - stageIdToShuffleWrite.remove(s.stageId) - stageIdToMemoryBytesSpilled.remove(s.stageId) - stageIdToDiskBytesSpilled.remove(s.stageId) - stageIdToTasksActive.remove(s.stageId) - stageIdToTasksComplete.remove(s.stageId) - stageIdToTasksFailed.remove(s.stageId) - stageIdToTaskData.remove(s.stageId) - stageIdToExecutorSummaries.remove(s.stageId) - stageIdToPool.remove(s.stageId) - stageIdToDescription.remove(s.stageId) - } + stages.take(toRemove).foreach { s => stageIdToData.remove(s.stageId) } stages.trimStart(toRemove) } } @@ -114,26 +93,27 @@ class JobProgressListener(conf: SparkConf) extends SparkListener { val poolName = Option(stageSubmitted.properties).map { p => p.getProperty("spark.scheduler.pool", DEFAULT_POOL_NAME) }.getOrElse(DEFAULT_POOL_NAME) - stageIdToPool(stage.stageId) = poolName - val description = Option(stageSubmitted.properties).flatMap { + val stageData = stageIdToData.getOrElseUpdate(stage.stageId, new StageUIData) + stageData.schedulingPool = poolName + + stageData.description = Option(stageSubmitted.properties).flatMap { p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION)) } - description.map(d => stageIdToDescription(stage.stageId) = d) val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashMap[Int, StageInfo]()) stages(stage.stageId) = stage } override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { - val sid = taskStart.stageId val taskInfo = taskStart.taskInfo if (taskInfo != null) { - val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashMap[Long, TaskInfo]()) - tasksActive(taskInfo.taskId) = taskInfo - val taskMap = stageIdToTaskData.getOrElse(sid, HashMap[Long, TaskUIData]()) - taskMap(taskInfo.taskId) = new TaskUIData(taskInfo) - stageIdToTaskData(sid) = taskMap + val stageData = stageIdToData.getOrElseUpdate(taskStart.stageId, { + logWarning("Task start for unknown stage " + taskStart.stageId) + new StageUIData + }) + stageData.numActiveTasks += 1 + stageData.taskData.put(taskInfo.taskId, new TaskUIData(taskInfo)) } } @@ -143,88 +123,76 @@ class JobProgressListener(conf: SparkConf) extends SparkListener { } override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { - val sid = taskEnd.stageId val info = taskEnd.taskInfo - if (info != null) { + val stageData = stageIdToData.getOrElseUpdate(taskEnd.stageId, { + logWarning("Task end for unknown stage " + taskEnd.stageId) + new StageUIData + }) + // create executor summary map if necessary - val executorSummaryMap = stageIdToExecutorSummaries.getOrElseUpdate(key = sid, - op = new HashMap[String, ExecutorSummary]()) + val executorSummaryMap = stageData.executorSummary executorSummaryMap.getOrElseUpdate(key = info.executorId, op = new ExecutorSummary) - val executorSummary = executorSummaryMap.get(info.executorId) - executorSummary match { - case Some(y) => { - // first update failed-task, succeed-task - taskEnd.reason match { - case Success => - y.succeededTasks += 1 - case _ => - y.failedTasks += 1 - } - - // update duration - y.taskTime += info.duration - - val metrics = taskEnd.taskMetrics - if (metrics != null) { - metrics.inputMetrics.foreach { y.inputBytes += _.bytesRead } - metrics.shuffleReadMetrics.foreach { y.shuffleRead += _.remoteBytesRead } - metrics.shuffleWriteMetrics.foreach { y.shuffleWrite += _.shuffleBytesWritten } - y.memoryBytesSpilled += metrics.memoryBytesSpilled - y.diskBytesSpilled += metrics.diskBytesSpilled - } + executorSummaryMap.get(info.executorId).foreach { y => + // first update failed-task, succeed-task + taskEnd.reason match { + case Success => + y.succeededTasks += 1 + case _ => + y.failedTasks += 1 + } + + // update duration + y.taskTime += info.duration + + val metrics = taskEnd.taskMetrics + if (metrics != null) { + metrics.inputMetrics.foreach { y.inputBytes += _.bytesRead } + metrics.shuffleReadMetrics.foreach { y.shuffleRead += _.remoteBytesRead } + metrics.shuffleWriteMetrics.foreach { y.shuffleWrite += _.shuffleBytesWritten } + y.memoryBytesSpilled += metrics.memoryBytesSpilled + y.diskBytesSpilled += metrics.diskBytesSpilled } - case _ => {} } - val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashMap[Long, TaskInfo]()) - // Remove by taskId, rather than by TaskInfo, in case the TaskInfo is from storage - tasksActive.remove(info.taskId) + stageData.numActiveTasks -= 1 val (errorMessage, metrics): (Option[String], Option[TaskMetrics]) = taskEnd.reason match { case org.apache.spark.Success => - stageIdToTasksComplete(sid) = stageIdToTasksComplete.getOrElse(sid, 0) + 1 + stageData.numCompleteTasks += 1 (None, Option(taskEnd.taskMetrics)) case e: ExceptionFailure => // Handle ExceptionFailure because we might have metrics - stageIdToTasksFailed(sid) = stageIdToTasksFailed.getOrElse(sid, 0) + 1 + stageData.numFailedTasks += 1 (Some(e.toErrorString), e.metrics) case e: TaskFailedReason => // All other failure cases - stageIdToTasksFailed(sid) = stageIdToTasksFailed.getOrElse(sid, 0) + 1 + stageData.numFailedTasks += 1 (Some(e.toErrorString), None) } - stageIdToTime.getOrElseUpdate(sid, 0L) - val time = metrics.map(_.executorRunTime).getOrElse(0L) - stageIdToTime(sid) += time - stageIdToInputBytes.getOrElseUpdate(sid, 0L) + val taskRunTime = metrics.map(_.executorRunTime).getOrElse(0L) + stageData.executorRunTime += taskRunTime val inputBytes = metrics.flatMap(_.inputMetrics).map(_.bytesRead).getOrElse(0L) - stageIdToInputBytes(sid) += inputBytes + stageData.inputBytes += inputBytes - stageIdToShuffleRead.getOrElseUpdate(sid, 0L) val shuffleRead = metrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead).getOrElse(0L) - stageIdToShuffleRead(sid) += shuffleRead + stageData.shuffleReadBytes += shuffleRead - stageIdToShuffleWrite.getOrElseUpdate(sid, 0L) val shuffleWrite = metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleBytesWritten).getOrElse(0L) - stageIdToShuffleWrite(sid) += shuffleWrite + stageData.shuffleWriteBytes += shuffleWrite - stageIdToMemoryBytesSpilled.getOrElseUpdate(sid, 0L) val memoryBytesSpilled = metrics.map(_.memoryBytesSpilled).getOrElse(0L) - stageIdToMemoryBytesSpilled(sid) += memoryBytesSpilled + stageData.memoryBytesSpilled += memoryBytesSpilled - stageIdToDiskBytesSpilled.getOrElseUpdate(sid, 0L) val diskBytesSpilled = metrics.map(_.diskBytesSpilled).getOrElse(0L) - stageIdToDiskBytesSpilled(sid) += diskBytesSpilled + stageData.diskBytesSpilled += diskBytesSpilled - val taskMap = stageIdToTaskData.getOrElse(sid, HashMap[Long, TaskUIData]()) - taskMap(info.taskId) = new TaskUIData(info, metrics, errorMessage) - stageIdToTaskData(sid) = taskMap + stageData.taskData(info.taskId) = new TaskUIData(info, metrics, errorMessage) } - } + } // end of onTaskEnd override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { synchronized { @@ -252,12 +220,6 @@ class JobProgressListener(conf: SparkConf) extends SparkListener { } -@DeveloperApi -case class TaskUIData( - taskInfo: TaskInfo, - taskMetrics: Option[TaskMetrics] = None, - errorMessage: Option[String] = None) - private object JobProgressListener { val DEFAULT_POOL_NAME = "default" val DEFAULT_RETAINED_STAGES = 1000 diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 8c3821bd7c3eb..cab26b9e2f7d3 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -23,6 +23,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils} +import org.apache.spark.ui.jobs.UIData._ import org.apache.spark.util.{Utils, Distribution} /** Page showing statistics and task list for a given stage */ @@ -34,8 +35,9 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { def render(request: HttpServletRequest): Seq[Node] = { listener.synchronized { val stageId = request.getParameter("id").toInt + val stageDataOption = listener.stageIdToData.get(stageId) - if (!listener.stageIdToTaskData.contains(stageId)) { + if (stageDataOption.isEmpty || stageDataOption.get.taskData.isEmpty) { val content =

Summary Metrics

No tasks have started yet @@ -45,23 +47,14 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { "Details for Stage %s".format(stageId), parent.headerTabs, parent) } - val tasks = listener.stageIdToTaskData(stageId).values.toSeq.sortBy(_.taskInfo.launchTime) + val stageData = stageDataOption.get + val tasks = stageData.taskData.values.toSeq.sortBy(_.taskInfo.launchTime) val numCompleted = tasks.count(_.taskInfo.finished) - val inputBytes = listener.stageIdToInputBytes.getOrElse(stageId, 0L) - val hasInput = inputBytes > 0 - val shuffleReadBytes = listener.stageIdToShuffleRead.getOrElse(stageId, 0L) - val hasShuffleRead = shuffleReadBytes > 0 - val shuffleWriteBytes = listener.stageIdToShuffleWrite.getOrElse(stageId, 0L) - val hasShuffleWrite = shuffleWriteBytes > 0 - val memoryBytesSpilled = listener.stageIdToMemoryBytesSpilled.getOrElse(stageId, 0L) - val diskBytesSpilled = listener.stageIdToDiskBytesSpilled.getOrElse(stageId, 0L) - val hasBytesSpilled = memoryBytesSpilled > 0 && diskBytesSpilled > 0 - - var activeTime = 0L - val now = System.currentTimeMillis - val tasksActive = listener.stageIdToTasksActive(stageId).values - tasksActive.foreach(activeTime += _.timeRunning(now)) + val hasInput = stageData.inputBytes > 0 + val hasShuffleRead = stageData.shuffleReadBytes > 0 + val hasShuffleWrite = stageData.shuffleWriteBytes > 0 + val hasBytesSpilled = stageData.memoryBytesSpilled > 0 && stageData.diskBytesSpilled > 0 // scalastyle:off val summary = @@ -69,34 +62,34 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
  • Total task time across all tasks: - {UIUtils.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)} + {UIUtils.formatDuration(stageData.executorRunTime)}
  • {if (hasInput)
  • Input: - {Utils.bytesToString(inputBytes)} + {Utils.bytesToString(stageData.inputBytes)}
  • } {if (hasShuffleRead)
  • Shuffle read: - {Utils.bytesToString(shuffleReadBytes)} + {Utils.bytesToString(stageData.shuffleReadBytes)}
  • } {if (hasShuffleWrite)
  • Shuffle write: - {Utils.bytesToString(shuffleWriteBytes)} + {Utils.bytesToString(stageData.shuffleWriteBytes)}
  • } {if (hasBytesSpilled)
  • Shuffle spill (memory): - {Utils.bytesToString(memoryBytesSpilled)} + {Utils.bytesToString(stageData.memoryBytesSpilled)}
  • Shuffle spill (disk): - {Utils.bytesToString(diskBytesSpilled)} + {Utils.bytesToString(stageData.diskBytesSpilled)}
  • }
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 4013c6f49936c..5f45c0ced5ec5 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -17,12 +17,11 @@ package org.apache.spark.ui.jobs -import java.util.Date - -import scala.collection.mutable.HashMap import scala.xml.Node -import org.apache.spark.scheduler.{StageInfo, TaskInfo} +import java.util.Date + +import org.apache.spark.scheduler.StageInfo import org.apache.spark.ui.{ToolTips, UIUtils} import org.apache.spark.util.Utils @@ -71,14 +70,14 @@ private[ui] class StageTableBase( } - private def makeProgressBar(started: Int, completed: Int, failed: String, total: Int): Seq[Node] = + private def makeProgressBar(started: Int, completed: Int, failed: Int, total: Int): Seq[Node] = { val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) val startWidth = "width: %s%%".format((started.toDouble/total)*100)
- {completed}/{total} {failed} + {completed}/{total} { if (failed > 0) s"($failed failed)" else "" }
@@ -89,7 +88,8 @@ private[ui] class StageTableBase( // scalastyle:off val killLink = if (killEnabled) { - (kill) + (kill) } // scalastyle:on @@ -107,13 +107,23 @@ private[ui] class StageTableBase( } - listener.stageIdToDescription.get(s.stageId) - .map(d =>
{d}
{nameLink} {killLink}
) - .getOrElse(
{killLink} {nameLink} {details}
) + val stageDataOption = listener.stageIdToData.get(s.stageId) + // Too many nested map/flatMaps with options are just annoying to read. Do this imperatively. + if (stageDataOption.isDefined && stageDataOption.get.description.isDefined) { + val desc = stageDataOption.get.description +
{desc}
{nameLink} {killLink}
+ } else { +
{killLink} {nameLink} {details}
+ } } protected def stageRow(s: StageInfo): Seq[Node] = { - val poolName = listener.stageIdToPool.get(s.stageId) + val stageDataOption = listener.stageIdToData.get(s.stageId) + if (stageDataOption.isEmpty) { + return {s.stageId}No data available for this stage + } + + val stageData = stageDataOption.get val submissionTime = s.submissionTime match { case Some(t) => UIUtils.formatDate(new Date(t)) case None => "Unknown" @@ -123,35 +133,20 @@ private[ui] class StageTableBase( if (finishTime > t) finishTime - t else System.currentTimeMillis - t } val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") - val startedTasks = - listener.stageIdToTasksActive.getOrElse(s.stageId, HashMap[Long, TaskInfo]()).size - val completedTasks = listener.stageIdToTasksComplete.getOrElse(s.stageId, 0) - val failedTasks = listener.stageIdToTasksFailed.getOrElse(s.stageId, 0) match { - case f if f > 0 => "(%s failed)".format(f) - case _ => "" - } - val totalTasks = s.numTasks - val inputSortable = listener.stageIdToInputBytes.getOrElse(s.stageId, 0L) - val inputRead = inputSortable match { - case 0 => "" - case b => Utils.bytesToString(b) - } - val shuffleReadSortable = listener.stageIdToShuffleRead.getOrElse(s.stageId, 0L) - val shuffleRead = shuffleReadSortable match { - case 0 => "" - case b => Utils.bytesToString(b) - } - val shuffleWriteSortable = listener.stageIdToShuffleWrite.getOrElse(s.stageId, 0L) - val shuffleWrite = shuffleWriteSortable match { - case 0 => "" - case b => Utils.bytesToString(b) - } + + val inputRead = stageData.inputBytes + val inputReadWithUnit = if (inputRead > 0) Utils.bytesToString(inputRead) else "" + val shuffleRead = stageData.shuffleReadBytes + val shuffleReadWithUnit = if (shuffleRead > 0) Utils.bytesToString(shuffleRead) else "" + val shuffleWrite = stageData.shuffleWriteBytes + val shuffleWriteWithUnit = if (shuffleWrite > 0) Utils.bytesToString(shuffleWrite) else "" + {s.stageId} ++ {if (isFairScheduler) { - {poolName.get} + .format(UIUtils.prependBaseUri(basePath), stageData.schedulingPool)}> + {stageData.schedulingPool} } else { @@ -161,11 +156,12 @@ private[ui] class StageTableBase( {submissionTime} {formattedDuration} - {makeProgressBar(startedTasks, completedTasks, failedTasks, totalTasks)} + {makeProgressBar(stageData.numActiveTasks, stageData.numCompleteTasks, + stageData.numFailedTasks, s.numTasks)} - {inputRead} - {shuffleRead} - {shuffleWrite} + {inputReadWithUnit} + {shuffleReadWithUnit} + {shuffleWriteWithUnit} } /** Render an HTML row that represents a stage */ diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala new file mode 100644 index 0000000000000..be11a11695b01 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler.TaskInfo + +import scala.collection.mutable.HashMap + +private[jobs] object UIData { + + class ExecutorSummary { + var taskTime : Long = 0 + var failedTasks : Int = 0 + var succeededTasks : Int = 0 + var inputBytes : Long = 0 + var shuffleRead : Long = 0 + var shuffleWrite : Long = 0 + var memoryBytesSpilled : Long = 0 + var diskBytesSpilled : Long = 0 + } + + class StageUIData { + var numActiveTasks: Int = _ + var numCompleteTasks: Int = _ + var numFailedTasks: Int = _ + + var executorRunTime: Long = _ + + var inputBytes: Long = _ + var shuffleReadBytes: Long = _ + var shuffleWriteBytes: Long = _ + var memoryBytesSpilled: Long = _ + var diskBytesSpilled: Long = _ + + var schedulingPool: String = "" + var description: Option[String] = None + + var taskData = new HashMap[Long, TaskUIData] + var executorSummary = new HashMap[String, ExecutorSummary] + } + + case class TaskUIData( + taskInfo: TaskInfo, + taskMetrics: Option[TaskMetrics] = None, + errorMessage: Option[String] = None) +} diff --git a/core/src/main/scala/org/apache/spark/util/FileLogger.scala b/core/src/main/scala/org/apache/spark/util/FileLogger.scala index 6a95dc06e155d..9dcdafdd6350e 100644 --- a/core/src/main/scala/org/apache/spark/util/FileLogger.scala +++ b/core/src/main/scala/org/apache/spark/util/FileLogger.scala @@ -196,6 +196,5 @@ private[spark] class FileLogger( def stop() { hadoopDataStream.foreach(_.close()) writer.foreach(_.close()) - fileSystem.close() } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 47eb44b530379..2ff8b25a56d10 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -527,8 +527,9 @@ private[spark] object JsonProtocol { metrics.resultSerializationTime = (json \ "Result Serialization Time").extract[Long] metrics.memoryBytesSpilled = (json \ "Memory Bytes Spilled").extract[Long] metrics.diskBytesSpilled = (json \ "Disk Bytes Spilled").extract[Long] - metrics.shuffleReadMetrics = - Utils.jsonOption(json \ "Shuffle Read Metrics").map(shuffleReadMetricsFromJson) + Utils.jsonOption(json \ "Shuffle Read Metrics").map { shuffleReadMetrics => + metrics.updateShuffleReadMetrics(shuffleReadMetricsFromJson(shuffleReadMetrics)) + } metrics.shuffleWriteMetrics = Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson) metrics.inputMetrics = diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index a2454e120a8ab..5784e974fbb67 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -44,7 +44,7 @@ import org.apache.spark.executor.ExecutorUncaughtExceptionHandler import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} /** CallSite represents a place in user code. It can have a short and a long form. */ -private[spark] case class CallSite(val short: String, val long: String) +private[spark] case class CallSite(short: String, long: String) /** * Various utility methods used by Spark. @@ -809,7 +809,12 @@ private[spark] object Utils extends Logging { */ def getCallSite: CallSite = { val trace = Thread.currentThread.getStackTrace() - .filterNot(_.getMethodName.contains("getStackTrace")) + .filterNot { ste:StackTraceElement => + // When running under some profilers, the current stack trace might contain some bogus + // frames. This is intended to ensure that we don't crash in these situations by + // ignoring any frames that we can't examine. + (ste == null || ste.getMethodName == null || ste.getMethodName.contains("getStackTrace")) + } // Keep crawling up the stack trace until we find the first function not inside of the spark // package. We track the last (shallowest) contiguous Spark method. This might be an RDD @@ -1286,4 +1291,19 @@ private[spark] object Utils extends Logging { } } + /** Return a nice string representation of the exception, including the stack trace. */ + def exceptionString(e: Exception): String = { + if (e == null) "" else exceptionString(getFormattedClassName(e), e.getMessage, e.getStackTrace) + } + + /** Return a nice string representation of the exception, including the stack trace. */ + def exceptionString( + className: String, + description: String, + stackTrace: Array[StackTraceElement]): String = { + val desc = if (description == null) "" else description + val st = if (stackTrace == null) "" else stackTrace.map(" " + _).mkString("\n") + s"$className: $desc\n$st" + } + } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 292d0962f4fdb..765254bf4c36e 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -268,10 +268,10 @@ class ExternalAppendOnlyMap[K, V, C]( private def mergeIfKeyExists(key: K, baseCombiner: C, buffer: StreamBuffer): C = { var i = 0 while (i < buffer.pairs.length) { - val (k, c) = buffer.pairs(i) - if (k == key) { + val pair = buffer.pairs(i) + if (pair._1 == key) { buffer.pairs.remove(i) - return mergeCombiners(baseCombiner, c) + return mergeCombiners(baseCombiner, pair._2) } i += 1 } @@ -293,9 +293,11 @@ class ExternalAppendOnlyMap[K, V, C]( } // Select a key from the StreamBuffer that holds the lowest key hash val minBuffer = mergeHeap.dequeue() - val (minPairs, minHash) = (minBuffer.pairs, minBuffer.minKeyHash) + val minPairs = minBuffer.pairs + val minHash = minBuffer.minKeyHash val minPair = minPairs.remove(0) - var (minKey, minCombiner) = minPair + val minKey = minPair._1 + var minCombiner = minPair._2 assert(getKeyHashCode(minPair) == minHash) // For all other streams that may have this key (i.e. have the same minimum key hash), diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala index a79e3ee756fc6..d10141b90e621 100644 --- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -17,8 +17,54 @@ package org.apache.spark.util.random +import scala.reflect.ClassTag +import scala.util.Random + private[spark] object SamplingUtils { + /** + * Reservoir sampling implementation that also returns the input size. + * + * @param input input size + * @param k reservoir size + * @param seed random seed + * @return (samples, input size) + */ + def reservoirSampleAndCount[T: ClassTag]( + input: Iterator[T], + k: Int, + seed: Long = Random.nextLong()) + : (Array[T], Int) = { + val reservoir = new Array[T](k) + // Put the first k elements in the reservoir. + var i = 0 + while (i < k && input.hasNext) { + val item = input.next() + reservoir(i) = item + i += 1 + } + + // If we have consumed all the elements, return them. Otherwise do the replacement. + if (i < k) { + // If input size < k, trim the array to return only an array of input size. + val trimReservoir = new Array[T](i) + System.arraycopy(reservoir, 0, trimReservoir, 0, i) + (trimReservoir, i) + } else { + // If input size > k, continue the sampling process. + val rand = new XORShiftRandom(seed) + while (input.hasNext) { + val item = input.next() + val replacementIndex = rand.nextInt(i) + if (replacementIndex < k) { + reservoir(replacementIndex) = item + } + i += 1 + } + (reservoir, i) + } + } + /** * Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of * the time. diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index e755d2e309398..2229e6acc425d 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -104,8 +104,9 @@ class FailureSuite extends FunSuite with LocalSparkContext { results.collect() } assert(thrown.getClass === classOf[SparkException]) - assert(thrown.getMessage.contains("NotSerializableException") || - thrown.getCause.getClass === classOf[NotSerializableException]) + assert(thrown.getMessage.contains("serializable") || + thrown.getCause.getClass === classOf[NotSerializableException], + "Exception does not contain \"serializable\": " + thrown.getMessage) FailureSuiteState.clear() } diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 070e974657860..c70e22cf09433 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -177,6 +177,31 @@ class FileSuite extends FunSuite with LocalSparkContext { assert(output.collect().toList === List((1, "a"), (2, "aa"), (3, "aaa"))) } + test("object files of classes from a JAR") { + val original = Thread.currentThread().getContextClassLoader + val className = "FileSuiteObjectFileTest" + val jar = TestUtils.createJarWithClasses(Seq(className)) + val loader = new java.net.URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader) + Thread.currentThread().setContextClassLoader(loader) + try { + sc = new SparkContext("local", "test") + val objs = sc.makeRDD(1 to 3).map { x => + val loader = Thread.currentThread().getContextClassLoader + Class.forName(className, true, loader).newInstance() + } + val outputDir = new File(tempDir, "output").getAbsolutePath + objs.saveAsObjectFile(outputDir) + // Try reading the output back as an object file + val ct = reflect.ClassTag[Any](Class.forName(className, true, loader)) + val output = sc.objectFile[Any](outputDir) + assert(output.collect().size === 3) + assert(output.collect().head.getClass.getName === className) + } + finally { + Thread.currentThread().setContextClassLoader(original) + } + } + test("write SequenceFile using new Hadoop API") { import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat sc = new SparkContext("local", "test") diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index c4f2f7e34f4d5..237e644b48e49 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -240,7 +240,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { } assert(thrown.getClass === classOf[SparkException]) - assert(thrown.getMessage.contains("NotSerializableException")) + assert(thrown.getMessage.toLowerCase.contains("serializable")) } } diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala similarity index 98% rename from core/src/test/scala/org/apache/spark/BroadcastSuite.scala rename to core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index c9936256a5b95..7c3d0208b195a 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -15,14 +15,12 @@ * limitations under the License. */ -package org.apache.spark +package org.apache.spark.broadcast +import org.apache.spark.storage.{BroadcastBlockId, _} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException} import org.scalatest.FunSuite -import org.apache.spark.storage._ -import org.apache.spark.broadcast.{Broadcast, HttpBroadcast} -import org.apache.spark.storage.BroadcastBlockId - class BroadcastSuite extends FunSuite with LocalSparkContext { private val httpConf = broadcastConf("HttpBroadcastFactory") diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index 68a0ea36aa545..3f882a724b047 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -46,7 +46,13 @@ class CompressionCodecSuite extends FunSuite { test("default compression codec") { val codec = CompressionCodec.createCodec(conf) - assert(codec.getClass === classOf[LZFCompressionCodec]) + assert(codec.getClass === classOf[SnappyCompressionCodec]) + testCodec(codec) + } + + test("lz4 compression codec") { + val codec = CompressionCodec.createCodec(conf, classOf[LZ4CompressionCodec].getName) + assert(codec.getClass === classOf[LZ4CompressionCodec]) testCodec(codec) } diff --git a/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala similarity index 97% rename from core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala rename to core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala index df6b2604c8d8a..415ad8c432c12 100644 --- a/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala @@ -15,15 +15,14 @@ * limitations under the License. */ -package org.apache.spark - -import org.scalatest.FunSuite +package org.apache.spark.network import java.nio._ -import org.apache.spark.network.{ConnectionManager, Message, ConnectionManagerId} -import scala.concurrent.Await -import scala.concurrent.TimeoutException +import org.apache.spark.{SecurityManager, SparkConf} +import org.scalatest.FunSuite + +import scala.concurrent.{Await, TimeoutException} import scala.concurrent.duration._ import scala.language.postfixOps diff --git a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala similarity index 95% rename from core/src/test/scala/org/apache/spark/PipedRDDSuite.scala rename to core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index db56a4acdd6f5..be972c5e97a7e 100644 --- a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -15,25 +15,21 @@ * limitations under the License. */ -package org.apache.spark +package org.apache.spark.rdd import java.io.File -import org.scalatest.FunSuite - -import org.apache.spark.rdd.{HadoopRDD, PipedRDD, HadoopPartition} -import org.apache.hadoop.mapred.{JobConf, TextInputFormat, FileSplit} import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.mapred.{FileSplit, JobConf, TextInputFormat} +import org.apache.spark._ +import org.scalatest.FunSuite import scala.collection.Map import scala.language.postfixOps import scala.sys.process._ import scala.util.Try -import org.apache.hadoop.io.{Text, LongWritable} - -import org.apache.spark.executor.TaskMetrics - class PipedRDDSuite extends FunSuite with SharedSparkContext { test("basic pipe") { diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 0f9cbe213ea17..2924de112934c 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -351,6 +351,20 @@ class RDDSuite extends FunSuite with SharedSparkContext { } } + // Test for SPARK-2412 -- ensure that the second pass of the algorithm does not throw an exception + test("coalesced RDDs with locality, fail first pass") { + val initialPartitions = 1000 + val targetLen = 50 + val couponCount = 2 * (math.log(targetLen)*targetLen + targetLen + 0.5).toInt // = 492 + + val blocks = (1 to initialPartitions).map { i => + (i, List(if (i > couponCount) "m2" else "m1")) + } + val data = sc.makeRDD(blocks) + val coalesced = data.coalesce(targetLen) + assert(coalesced.partitions.length == targetLen) + } + test("zipped RDDs") { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) val zipped = nums.zip(nums.map(_ + 1.0)) @@ -379,6 +393,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("mapWith") { import java.util.Random val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) + @deprecated("suppress compile time deprecation warning", "1.0.0") val randoms = ones.mapWith( (index: Int) => new Random(index + 42)) {(t: Int, prng: Random) => prng.nextDouble * t}.collect() @@ -397,6 +412,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("flatMapWith") { import java.util.Random val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) + @deprecated("suppress compile time deprecation warning", "1.0.0") val randoms = ones.flatMapWith( (index: Int) => new Random(index + 42)) {(t: Int, prng: Random) => @@ -418,6 +434,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("filterWith") { import java.util.Random val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2) + @deprecated("suppress compile time deprecation warning", "1.0.0") val sample = ints.filterWith( (index: Int) => new Random(index + 42)) {(t: Int, prng: Random) => prng.nextInt(3) == 0}. diff --git a/core/src/test/scala/org/apache/spark/ZippedPartitionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ZippedPartitionsSuite.scala similarity index 95% rename from core/src/test/scala/org/apache/spark/ZippedPartitionsSuite.scala rename to core/src/test/scala/org/apache/spark/rdd/ZippedPartitionsSuite.scala index 4f87fd8654c4a..72596e86865b2 100644 --- a/core/src/test/scala/org/apache/spark/ZippedPartitionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/ZippedPartitionsSuite.scala @@ -15,8 +15,9 @@ * limitations under the License. */ -package org.apache.spark +package org.apache.spark.rdd +import org.apache.spark.SharedSparkContext import org.scalatest.FunSuite object ZippedPartitionsSuite { diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 71f48e295ecca..3b0b8e2f68c97 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -258,8 +258,8 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers if (stageInfo.rddInfos.exists(_.name == d4.name)) { taskMetrics.shuffleReadMetrics should be ('defined) val sm = taskMetrics.shuffleReadMetrics.get - sm.totalBlocksFetched should be > (0) - sm.localBlocksFetched should be > (0) + sm.totalBlocksFetched should be (128) + sm.localBlocksFetched should be (128) sm.remoteBlocksFetched should be (0) sm.remoteBytesRead should be (0l) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 9ff2a487005c4..86b443b18f2a6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -54,6 +54,23 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) } } +// Get the rack for a given host +object FakeRackUtil { + private val hostToRack = new mutable.HashMap[String, String]() + + def cleanUp() { + hostToRack.clear() + } + + def assignHostToRack(host: String, rack: String) { + hostToRack(host) = rack + } + + def getRackForHost(host: String) = { + hostToRack.get(host) + } +} + /** * A mock TaskSchedulerImpl implementation that just remembers information about tasks started and * feedback received from the TaskSetManagers. Note that it's important to initialize this with @@ -69,6 +86,9 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex val taskSetsFailed = new ArrayBuffer[String] val executors = new mutable.HashMap[String, String] ++ liveExecutors + for ((execId, host) <- liveExecutors; rack <- getRackForHost(host)) { + hostsByRack.getOrElseUpdate(rack, new mutable.HashSet[String]()) += host + } dagScheduler = new FakeDAGScheduler(sc, this) @@ -82,7 +102,12 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex def addExecutor(execId: String, host: String) { executors.put(execId, host) + for (rack <- getRackForHost(host)) { + hostsByRack.getOrElseUpdate(rack, new mutable.HashSet[String]()) += host + } } + + override def getRackForHost(value: String): Option[String] = FakeRackUtil.getRackForHost(value) } /** @@ -419,6 +444,9 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { } test("new executors get added") { + // Assign host2 to rack2 + FakeRackUtil.cleanUp() + FakeRackUtil.assignHostToRack("host2", "rack2") sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc) val taskSet = FakeTask.createTaskSet(4, @@ -444,8 +472,39 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { manager.executorAdded() // No-pref list now only contains task 3 assert(manager.pendingTasksWithNoPrefs.size === 1) - // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL and ANY - assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) + // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL and ANY + assert(manager.myLocalityLevels.sameElements( + Array(PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY))) + } + + test("test RACK_LOCAL tasks") { + FakeRackUtil.cleanUp() + // Assign host1 to rack1 + FakeRackUtil.assignHostToRack("host1", "rack1") + // Assign host2 to rack1 + FakeRackUtil.assignHostToRack("host2", "rack1") + // Assign host3 to rack2 + FakeRackUtil.assignHostToRack("host3", "rack2") + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, + ("execA", "host1"), ("execB", "host2"), ("execC", "host3")) + val taskSet = FakeTask.createTaskSet(2, + Seq(TaskLocation("host1", "execA")), + Seq(TaskLocation("host1", "execA"))) + val clock = new FakeClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + + assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY))) + // Set allowed locality to ANY + clock.advance(LOCALITY_WAIT * 3) + // Offer host3 + // No task is scheduled if we restrict locality to RACK_LOCAL + assert(manager.resourceOffer("execC", "host3", RACK_LOCAL) === None) + // Task 0 can be scheduled with ANY + assert(manager.resourceOffer("execC", "host3", ANY).get.index === 0) + // Offer host2 + // Task 1 can be scheduled with RACK_LOCAL + assert(manager.resourceOffer("execB", "host2", RACK_LOCAL).get.index === 1) } test("do not emit warning when serialized task is small") { diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala index 5d15a68ac7e4f..aad6599589420 100644 --- a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala @@ -15,15 +15,12 @@ * limitations under the License. */ -package org.apache.spark.serializer; - -import java.io.NotSerializableException +package org.apache.spark.serializer import org.scalatest.FunSuite +import org.apache.spark.{SharedSparkContext, SparkException} import org.apache.spark.rdd.RDD -import org.apache.spark.SparkException -import org.apache.spark.SharedSparkContext /* A trivial (but unserializable) container for trivial functions */ class UnserializableClass { @@ -38,52 +35,50 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex test("throws expected serialization exceptions on actions") { val (data, uc) = fixture - val ex = intercept[SparkException] { - data.map(uc.op(_)).count + data.map(uc.op(_)).count() } - assert(ex.getMessage.contains("Task not serializable")) } // There is probably a cleaner way to eliminate boilerplate here, but we're // iterating over a map from transformation names to functions that perform that // transformation on a given RDD, creating one test case for each - + for (transformation <- - Map("map" -> xmap _, "flatMap" -> xflatMap _, "filter" -> xfilter _, - "mapWith" -> xmapWith _, "mapPartitions" -> xmapPartitions _, + Map("map" -> xmap _, + "flatMap" -> xflatMap _, + "filter" -> xfilter _, + "mapPartitions" -> xmapPartitions _, "mapPartitionsWithIndex" -> xmapPartitionsWithIndex _, - "mapPartitionsWithContext" -> xmapPartitionsWithContext _, - "filterWith" -> xfilterWith _)) { + "mapPartitionsWithContext" -> xmapPartitionsWithContext _)) { val (name, xf) = transformation - + test(s"$name transformations throw proactive serialization exceptions") { val (data, uc) = fixture - val ex = intercept[SparkException] { xf(data, uc) } - assert(ex.getMessage.contains("Task not serializable"), s"RDD.$name doesn't proactively throw NotSerializableException") } } - + private def xmap(x: RDD[String], uc: UnserializableClass): RDD[String] = x.map(y=>uc.op(y)) - private def xmapWith(x: RDD[String], uc: UnserializableClass): RDD[String] = - x.mapWith(x => x.toString)((x,y)=>x + uc.op(y)) + private def xflatMap(x: RDD[String], uc: UnserializableClass): RDD[String] = x.flatMap(y=>Seq(uc.op(y))) + private def xfilter(x: RDD[String], uc: UnserializableClass): RDD[String] = x.filter(y=>uc.pred(y)) - private def xfilterWith(x: RDD[String], uc: UnserializableClass): RDD[String] = - x.filterWith(x => x.toString)((x,y)=>uc.pred(y)) + private def xmapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] = x.mapPartitions(_.map(y=>uc.op(y))) + private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] = x.mapPartitionsWithIndex((_, it) => it.map(y=>uc.op(y))) + private def xmapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] = x.mapPartitionsWithContext((_, it) => it.map(y=>uc.op(y))) diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index fa43b66c6cb5a..b52f81877d557 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -47,11 +47,11 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc } listener.completedStages.size should be (5) - listener.completedStages.filter(_.stageId == 50).size should be (1) - listener.completedStages.filter(_.stageId == 49).size should be (1) - listener.completedStages.filter(_.stageId == 48).size should be (1) - listener.completedStages.filter(_.stageId == 47).size should be (1) - listener.completedStages.filter(_.stageId == 46).size should be (1) + listener.completedStages.count(_.stageId == 50) should be (1) + listener.completedStages.count(_.stageId == 49) should be (1) + listener.completedStages.count(_.stageId == 48) should be (1) + listener.completedStages.count(_.stageId == 47) should be (1) + listener.completedStages.count(_.stageId == 46) should be (1) } test("test executor id to summary") { @@ -59,20 +59,18 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val listener = new JobProgressListener(conf) val taskMetrics = new TaskMetrics() val shuffleReadMetrics = new ShuffleReadMetrics() - - // nothing in it - assert(listener.stageIdToExecutorSummaries.size == 0) + assert(listener.stageIdToData.size === 0) // finish this task, should get updated shuffleRead shuffleReadMetrics.remoteBytesRead = 1000 - taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics) + taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics) var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 var task = new ShuffleMapTask(0, null, null, 0, null) val taskType = Utils.getFormattedClassName(task) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) - assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-1", fail()) - .shuffleRead == 1000) + assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-1", fail()) + .shuffleRead === 1000) // finish a task with unknown executor-id, nothing should happen taskInfo = @@ -80,27 +78,23 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskInfo.finishTime = 1 task = new ShuffleMapTask(0, null, null, 0, null) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) - assert(listener.stageIdToExecutorSummaries.size == 1) + assert(listener.stageIdToData.size === 1) // finish this task, should get updated duration - shuffleReadMetrics.remoteBytesRead = 1000 - taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics) taskInfo = new TaskInfo(1235L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 task = new ShuffleMapTask(0, null, null, 0, null) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) - assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-1", fail()) - .shuffleRead == 2000) + assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-1", fail()) + .shuffleRead === 2000) // finish this task, should get updated duration - shuffleReadMetrics.remoteBytesRead = 1000 - taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics) taskInfo = new TaskInfo(1236L, 0, 2, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 task = new ShuffleMapTask(0, null, null, 0, null) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) - assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-2", fail()) - .shuffleRead == 1000) + assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-2", fail()) + .shuffleRead === 1000) } test("test task success vs failure counting for different task end reasons") { @@ -121,13 +115,17 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc TaskKilled, ExecutorLostFailure, UnknownReason) + var failCount = 0 for (reason <- taskFailedReasons) { listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, reason, taskInfo, metrics)) - assert(listener.stageIdToTasksComplete.get(task.stageId) === None) + failCount += 1 + assert(listener.stageIdToData(task.stageId).numCompleteTasks === 0) + assert(listener.stageIdToData(task.stageId).numFailedTasks === failCount) } // Make sure we count success as success. listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, metrics)) - assert(listener.stageIdToTasksComplete.get(task.stageId) === Some(1)) + assert(listener.stageIdToData(task.stageId).numCompleteTasks === 1) + assert(listener.stageIdToData(task.stageId).numFailedTasks === failCount) } } diff --git a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala similarity index 99% rename from core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala rename to core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index 4ab870e751778..c4765e53de17b 100644 --- a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark - -import org.scalatest.FunSuite +package org.apache.spark.util import akka.actor._ +import org.apache.spark._ import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.AkkaUtils +import org.scalatest.FunSuite + import scala.concurrent.Await /** diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index ca37d707b06ca..d2bee448d4d3b 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -135,12 +135,11 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging { val testOutputStream = new PipedOutputStream() val testInputStream = new PipedInputStream(testOutputStream) val appender = FileAppender(testInputStream, testFile, conf) - assert(appender.isInstanceOf[ExpectedAppender]) + //assert(appender.getClass === classTag[ExpectedAppender].getClass) assert(appender.getClass.getSimpleName === classTag[ExpectedAppender].runtimeClass.getSimpleName) if (appender.isInstanceOf[RollingFileAppender]) { val rollingPolicy = appender.asInstanceOf[RollingFileAppender].rollingPolicy - rollingPolicy.isInstanceOf[ExpectedRollingPolicy] val policyParam = if (rollingPolicy.isInstanceOf[TimeBasedRollingPolicy]) { rollingPolicy.asInstanceOf[TimeBasedRollingPolicy].rolloverIntervalMillis } else { diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 058d31453081a..11f70a6090d24 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -518,7 +518,7 @@ class JsonProtocolSuite extends FunSuite { sr.localBlocksFetched = e sr.fetchWaitTime = a + d sr.remoteBlocksFetched = f - t.shuffleReadMetrics = Some(sr) + t.updateShuffleReadMetrics(sr) } sw.shuffleBytesWritten = a + b + c sw.shuffleWriteTime = b + c + d diff --git a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala index 7006571ef0ef6..794a55d61750b 100644 --- a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala @@ -24,6 +24,7 @@ import org.scalatest.FunSuite /** * Tests org.apache.spark.util.Vector functionality */ +@deprecated("suppress compile time deprecation warning", "1.0.0") class VectorSuite extends FunSuite { def verifyVector(vector: Vector, expectedLength: Int) = { diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala index accfe2e9b7f2a..73a9d029b0248 100644 --- a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala @@ -17,11 +17,32 @@ package org.apache.spark.util.random +import scala.util.Random + import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution} import org.scalatest.FunSuite class SamplingUtilsSuite extends FunSuite { + test("reservoirSampleAndCount") { + val input = Seq.fill(100)(Random.nextInt()) + + // input size < k + val (sample1, count1) = SamplingUtils.reservoirSampleAndCount(input.iterator, 150) + assert(count1 === 100) + assert(input === sample1.toSeq) + + // input size == k + val (sample2, count2) = SamplingUtils.reservoirSampleAndCount(input.iterator, 100) + assert(count2 === 100) + assert(input === sample2.toSeq) + + // input size > k + val (sample3, count3) = SamplingUtils.reservoirSampleAndCount(input.iterator, 10) + assert(count3 === 100) + assert(sample3.length === 10) + } + test("computeFraction") { // test that the computed fraction guarantees enough data points // in the sample with a failure rate <= 0.0001 diff --git a/mllib/data/als/test.data b/data/mllib/als/test.data similarity index 100% rename from mllib/data/als/test.data rename to data/mllib/als/test.data diff --git a/data/kmeans_data.txt b/data/mllib/kmeans_data.txt similarity index 100% rename from data/kmeans_data.txt rename to data/mllib/kmeans_data.txt diff --git a/mllib/data/lr-data/random.data b/data/mllib/lr-data/random.data similarity index 100% rename from mllib/data/lr-data/random.data rename to data/mllib/lr-data/random.data diff --git a/data/lr_data.txt b/data/mllib/lr_data.txt similarity index 100% rename from data/lr_data.txt rename to data/mllib/lr_data.txt diff --git a/data/pagerank_data.txt b/data/mllib/pagerank_data.txt similarity index 100% rename from data/pagerank_data.txt rename to data/mllib/pagerank_data.txt diff --git a/mllib/data/ridge-data/lpsa.data b/data/mllib/ridge-data/lpsa.data similarity index 100% rename from mllib/data/ridge-data/lpsa.data rename to data/mllib/ridge-data/lpsa.data diff --git a/mllib/data/sample_libsvm_data.txt b/data/mllib/sample_libsvm_data.txt similarity index 100% rename from mllib/data/sample_libsvm_data.txt rename to data/mllib/sample_libsvm_data.txt diff --git a/mllib/data/sample_naive_bayes_data.txt b/data/mllib/sample_naive_bayes_data.txt similarity index 100% rename from mllib/data/sample_naive_bayes_data.txt rename to data/mllib/sample_naive_bayes_data.txt diff --git a/mllib/data/sample_svm_data.txt b/data/mllib/sample_svm_data.txt similarity index 100% rename from mllib/data/sample_svm_data.txt rename to data/mllib/sample_svm_data.txt diff --git a/mllib/data/sample_tree_data.csv b/data/mllib/sample_tree_data.csv similarity index 100% rename from mllib/data/sample_tree_data.csv rename to data/mllib/sample_tree_data.csv diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 49bf78f60763a..38830103d1e8d 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -95,7 +95,7 @@ make_binary_release() { cp -r spark spark-$RELEASE_VERSION-bin-$NAME cd spark-$RELEASE_VERSION-bin-$NAME - ./make-distribution.sh $FLAGS --name $NAME --tgz + ./make-distribution.sh --name $NAME --tgz $FLAGS cd .. cp spark-$RELEASE_VERSION-bin-$NAME/spark-$RELEASE_VERSION-bin-$NAME.tgz . rm -rf spark-$RELEASE_VERSION-bin-$NAME @@ -111,9 +111,10 @@ make_binary_release() { spark-$RELEASE_VERSION-bin-$NAME.tgz.sha } -make_binary_release "hadoop1" "--with-hive --hadoop 1.0.4" -make_binary_release "cdh4" "--with-hive --hadoop 2.0.0-mr1-cdh4.2.0" -make_binary_release "hadoop2" "--with-hive --with-yarn --hadoop 2.2.0" +make_binary_release "hadoop1" "-Phive -Dhadoop.version=1.0.4" +make_binary_release "cdh4" "-Phive -Dhadoop.version=2.0.0-mr1-cdh4.2.0" +make_binary_release "hadoop2" \ + "-Phive -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" # Copy data echo "Copying release tarballs" diff --git a/dev/github_jira_sync.py b/dev/github_jira_sync.py new file mode 100755 index 0000000000000..8051080117062 --- /dev/null +++ b/dev/github_jira_sync.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Utility for updating JIRA's with information about Github pull requests + +import json +import os +import re +import sys +import urllib2 + +try: + import jira.client +except ImportError: + print "This tool requires the jira-python library" + print "Install using 'sudo pip install jira-python'" + sys.exit(-1) + +# User facing configs +GITHUB_API_BASE = os.environ.get("GITHUB_API_BASE", "https://api.github.com/repos/apache/spark") +JIRA_API_BASE = os.environ.get("JIRA_API_BASE", "https://issues.apache.org/jira") +JIRA_USERNAME = os.environ.get("JIRA_USERNAME", "apachespark") +JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", "XXX") +# Maximum number of updates to perform in one run +MAX_UPDATES = int(os.environ.get("MAX_UPDATES", "100000")) +# Cut-off for oldest PR on which to comment. Useful for avoiding +# "notification overload" when running for the first time. +MIN_COMMENT_PR = int(os.environ.get("MIN_COMMENT_PR", "1496")) + +# File used as an opitimization to store maximum previously seen PR +# Used mostly because accessing ASF JIRA is slow, so we want to avoid checking +# the state of JIRA's that are tied to PR's we've already looked at. +MAX_FILE = ".github-jira-max" + +def get_url(url): + try: + return urllib2.urlopen(url) + except urllib2.HTTPError as e: + print "Unable to fetch URL, exiting: %s" % url + sys.exit(-1) + +def get_json(urllib_response): + return json.load(urllib_response) + +# Return a list of (JIRA id, JSON dict) tuples: +# e.g. [('SPARK-1234', {.. json ..}), ('SPARK-5687', {.. json ..})} +def get_jira_prs(): + result = [] + has_next_page = True + page_num = 0 + while has_next_page: + page = get_url(GITHUB_API_BASE + "/pulls?page=%s&per_page=100" % page_num) + page_json = get_json(page) + + for pull in page_json: + jiras = re.findall("SPARK-[0-9]{4,5}", pull['title']) + for jira in jiras: + result = result + [(jira, pull)] + + # Check if there is another page + link_header = filter(lambda k: k.startswith("Link"), page.info().headers)[0] + if not "next"in link_header: + has_next_page = False + else: + page_num = page_num + 1 + return result + +def set_max_pr(max_val): + f = open(MAX_FILE, 'w') + f.write("%s" % max_val) + f.close() + print "Writing largest PR number seen: %s" % max_val + +def get_max_pr(): + if os.path.exists(MAX_FILE): + result = int(open(MAX_FILE, 'r').read()) + print "Read largest PR number previously seen: %s" % result + return result + else: + return 0 + +jira_client = jira.client.JIRA({'server': JIRA_API_BASE}, + basic_auth=(JIRA_USERNAME, JIRA_PASSWORD)) + +jira_prs = get_jira_prs() + +previous_max = get_max_pr() +print "Retrieved %s JIRA PR's from Github" % len(jira_prs) +jira_prs = [(k, v) for k, v in jira_prs if int(v['number']) > previous_max] +print "%s PR's remain after excluding visted ones" % len(jira_prs) + +num_updates = 0 +considered = [] +for issue, pr in sorted(jira_prs, key=lambda (k, v): int(v['number'])): + if num_updates >= MAX_UPDATES: + break + pr_num = int(pr['number']) + + print "Checking issue %s" % issue + considered = considered + [pr_num] + + url = pr['html_url'] + title = "[Github] Pull Request #%s (%s)" % (pr['number'], pr['user']['login']) + try: + existing_links = map(lambda l: l.raw['object']['url'], jira_client.remote_links(issue)) + except: + print "Failure reading JIRA %s (does it exist?)" % issue + print sys.exc_info()[0] + continue + + if url in existing_links: + continue + + icon = {"title": "Pull request #%s" % pr['number'], + "url16x16": "https://assets-cdn.github.com/favicon.ico"} + destination = {"title": title, "url": url, "icon": icon} + # For all possible fields see: + # https://developer.atlassian.com/display/JIRADEV/Fields+in+Remote+Issue+Links + # application = {"name": "Github pull requests", "type": "org.apache.spark.jira.github"} + jira_client.add_remote_link(issue, destination) + + comment = "User '%s' has created a pull request for this issue:" % pr['user']['login'] + comment = comment + ("\n%s" % pr['html_url']) + if pr_num >= MIN_COMMENT_PR: + jira_client.add_comment(issue, comment) + + print "Added link %s <-> PR #%s" % (issue, pr['number']) + num_updates = num_updates + 1 + +if len(considered) > 0: + set_max_pr(max(considered)) diff --git a/dev/run-tests b/dev/run-tests index d9df020f7563c..51e4def0f835a 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -21,8 +21,7 @@ FWDIR="$(cd `dirname $0`/..; pwd)" cd $FWDIR -export SPARK_HADOOP_VERSION=2.3.0 -export SPARK_YARN=true +export SBT_MAVEN_PROFILES="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" # Remove work directory rm -rf ./work @@ -66,10 +65,10 @@ echo "=========================================================================" # (either resolution or compilation) prompts the user for input either q, r, # etc to quit or retry. This echo is there to make it not block. if [ -n "$_RUN_SQL_TESTS" ]; then - echo -e "q\n" | SPARK_HIVE=true sbt/sbt clean assembly test | \ - grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" + echo -e "q\n" | SBT_MAVEN_PROFILES="$SBT_MAVEN_PROFILES -Phive" sbt/sbt clean package \ + assembly/assembly test | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" else - echo -e "q\n" | sbt/sbt clean assembly test | \ + echo -e "q\n" | sbt/sbt clean package assembly/assembly test | \ grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" fi diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins new file mode 100755 index 0000000000000..8dda671e976ce --- /dev/null +++ b/dev/run-tests-jenkins @@ -0,0 +1,85 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Wrapper script that runs the Spark tests then reports QA results +# to github via its API. + +# Go to the Spark project root directory +FWDIR="$(cd `dirname $0`/..; pwd)" +cd $FWDIR + +COMMENTS_URL="https://api.github.com/repos/apache/spark/issues/$ghprbPullId/comments" + +function post_message { + message=$1 + data="{\"body\": \"$message\"}" + echo "Attempting to post to Github:" + echo "$data" + + curl -D- -u x-oauth-basic:$GITHUB_OAUTH_KEY -X POST --data "$data" -H \ + "Content-Type: application/json" \ + $COMMENTS_URL | head -n 8 +} + +start_message="QA tests have started for PR $ghprbPullId." +if [ "$sha1" == "$ghprbActualCommit" ]; then + start_message="$start_message This patch DID NOT merge cleanly! " +else + start_message="$start_message This patch merges cleanly. " +fi +start_message="$start_message
View progress: " +start_message="$start_message${BUILD_URL}consoleFull" + +post_message "$start_message" + +./dev/run-tests +test_result="$?" + +result_message="QA results for PR $ghprbPullId:
" + +if [ "$test_result" -eq "0" ]; then + result_message="$result_message- This patch PASSES unit tests.
" +else + result_message="$result_message- This patch FAILED unit tests.
" +fi + +if [ "$sha1" != "$ghprbActualCommit" ]; then + result_message="$result_message- This patch merges cleanly
" + non_test_files=$(git diff master --name-only | grep -v "\/test" | tr "\n" " ") + new_public_classes=$(git diff master $non_test_files \ + | grep -e "trait " -e "class " \ + | grep -e "{" -e "(" \ + | grep -v -e \@\@ -e private \ + | grep \+ \ + | sed "s/\+ *//" \ + | tr "\n" "~" \ + | sed "s/~/
/g") + if [ "$new_public_classes" == "" ]; then + result_message="$result_message- This patch adds no public classes
" + else + result_message="$result_message- This patch adds the following public classes (experimental):
" + result_message="$result_message$new_public_classes" + fi +fi +result_message="${result_message}
For more information see test ouptut:" +result_message="${result_message}
${BUILD_URL}consoleFull" + +post_message "$result_message" + +exit $test_result diff --git a/dev/scalastyle b/dev/scalastyle index 0e8fd5cc8d64c..a02d06912f238 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -17,12 +17,12 @@ # limitations under the License. # -echo -e "q\n" | SPARK_HIVE=true sbt/sbt scalastyle > scalastyle.txt +echo -e "q\n" | sbt/sbt -Phive scalastyle > scalastyle.txt # Check style with YARN alpha built too -echo -e "q\n" | SPARK_HADOOP_VERSION=0.23.9 SPARK_YARN=true sbt/sbt yarn-alpha/scalastyle \ +echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \ >> scalastyle.txt # Check style with YARN built too -echo -e "q\n" | SPARK_HADOOP_VERSION=2.2.0 SPARK_YARN=true sbt/sbt yarn/scalastyle \ +echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 yarn/scalastyle \ >> scalastyle.txt ERRORS=$(cat scalastyle.txt | grep -e "\") diff --git a/docs/bagel-programming-guide.md b/docs/bagel-programming-guide.md index b280df0c8eeb8..7e55131754a3f 100644 --- a/docs/bagel-programming-guide.md +++ b/docs/bagel-programming-guide.md @@ -46,7 +46,7 @@ import org.apache.spark.bagel.Bagel._ Next, we load a sample graph from a text file as a distributed dataset and package it into `PRVertex` objects. We also cache the distributed dataset because Bagel will use it multiple times and we'd like to avoid recomputing it. {% highlight scala %} -val input = sc.textFile("data/pagerank_data.txt") +val input = sc.textFile("data/mllib/pagerank_data.txt") val numVerts = input.count() diff --git a/docs/configuration.md b/docs/configuration.md index b84104cc7e653..a70007c165442 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -336,13 +336,12 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.codec - org.apache.spark.io.
LZFCompressionCodec + org.apache.spark.io.
SnappyCompressionCodec The codec used to compress internal data such as RDD partitions and shuffle outputs. - By default, Spark provides two codecs: org.apache.spark.io.LZFCompressionCodec - and org.apache.spark.io.SnappyCompressionCodec. Of these two choices, - Snappy offers faster compression and decompression, while LZF offers a better compression - ratio. + By default, Spark provides three codecs: org.apache.spark.io.LZ4CompressionCodec, + org.apache.spark.io.LZFCompressionCodec, + and org.apache.spark.io.SnappyCompressionCodec. @@ -350,7 +349,15 @@ Apart from these, the following properties are also available, and may be useful 32768 Block size (in bytes) used in Snappy compression, in the case when Snappy compression codec - is used. + is used. Lowering this block size will also lower shuffle memory usage when Snappy is used. + + + + spark.io.compression.lz4.block.size + 32768 + + Block size (in bytes) used in LZ4 compression, in the case when LZ4 compression codec + is used. Lowering this block size will also lower shuffle memory usage when LZ4 is used. @@ -412,7 +419,7 @@ Apart from these, the following properties are also available, and may be useful spark.broadcast.factory - org.apache.spark.broadcast.
HttpBroadcastFactory + org.apache.spark.broadcast.
TorrentBroadcastFactory Which broadcast implementation to use. @@ -699,6 +706,25 @@ Apart from these, the following properties are also available, and may be useful (in milliseconds) + + spark.scheduler.minRegisteredExecutorsRatio + 0 + + The minimum ratio of registered executors (registered executors / total expected executors) + to wait for before scheduling begins. Specified as a double between 0 and 1. + Regardless of whether the minimum ratio of executors has been reached, + the maximum amount of time it will wait before scheduling begins is controlled by config + spark.scheduler.maxRegisteredExecutorsWaitingTime + + + + spark.scheduler.maxRegisteredExecutorsWaitingTime + 30000 + + Maximum amount of time to wait for executors to register before scheduling begins + (in milliseconds). + + #### Security @@ -773,6 +799,15 @@ Apart from these, the following properties are also available, and may be useful into blocks of data before storing them in Spark. + + spark.streaming.receiver.maxRate + infinite + + Maximum rate (per second) at which each receiver will push data into blocks. Effectively, + each stream will consume at most this number of records per second. + Setting this configuration to 0 or a negative number will put no limit on the rate. + + spark.streaming.unpersist true diff --git a/docs/hadoop-third-party-distributions.md b/docs/hadoop-third-party-distributions.md index 32403bc6957a2..ab1023b8f1842 100644 --- a/docs/hadoop-third-party-distributions.md +++ b/docs/hadoop-third-party-distributions.md @@ -48,9 +48,9 @@ the _exact_ Hadoop version you are running to avoid any compatibility errors. -In SBT, the equivalent can be achieved by setting the SPARK_HADOOP_VERSION flag: +In SBT, the equivalent can be achieved by setting the the `hadoop.version` property: - SPARK_HADOOP_VERSION=1.0.4 sbt/sbt assembly + sbt/sbt -Dhadoop.version=1.0.4 assembly # Linking Applications to the Hadoop Version diff --git a/docs/mllib-basics.md b/docs/mllib-basics.md index 5796e16e8f99c..f9585251fafac 100644 --- a/docs/mllib-basics.md +++ b/docs/mllib-basics.md @@ -193,7 +193,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -val examples: RDD[LabeledPoint] = MLUtils.loadLibSVMFile(sc, "mllib/data/sample_libsvm_data.txt") +val examples: RDD[LabeledPoint] = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") {% endhighlight %}
@@ -207,7 +207,7 @@ import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.api.java.JavaRDD; JavaRDD examples = - MLUtils.loadLibSVMFile(jsc.sc(), "mllib/data/sample_libsvm_data.txt").toJavaRDD(); + MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD(); {% endhighlight %}
@@ -218,7 +218,7 @@ examples stored in LIBSVM format. {% highlight python %} from pyspark.mllib.util import MLUtils -examples = MLUtils.loadLibSVMFile(sc, "mllib/data/sample_libsvm_data.txt") +examples = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") {% endhighlight %} diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 429cdf8d40cec..c76ac010d3f81 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -51,7 +51,7 @@ import org.apache.spark.mllib.clustering.KMeans import org.apache.spark.mllib.linalg.Vectors // Load and parse the data -val data = sc.textFile("data/kmeans_data.txt") +val data = sc.textFile("data/mllib/kmeans_data.txt") val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))) // Cluster the data into two classes using KMeans @@ -86,7 +86,7 @@ from numpy import array from math import sqrt # Load and parse the data -data = sc.textFile("data/kmeans_data.txt") +data = sc.textFile("data/mllib/kmeans_data.txt") parsedData = data.map(lambda line: array([float(x) for x in line.split(' ')])) # Build the model (cluster the data) diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index d51002f015670..5cd71738722a9 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -58,7 +58,7 @@ import org.apache.spark.mllib.recommendation.ALS import org.apache.spark.mllib.recommendation.Rating // Load and parse the data -val data = sc.textFile("mllib/data/als/test.data") +val data = sc.textFile("data/mllib/als/test.data") val ratings = data.map(_.split(',') match { case Array(user, item, rate) => Rating(user.toInt, item.toInt, rate.toDouble) }) @@ -112,7 +112,7 @@ from pyspark.mllib.recommendation import ALS from numpy import array # Load and parse the data -data = sc.textFile("mllib/data/als/test.data") +data = sc.textFile("data/mllib/als/test.data") ratings = data.map(lambda line: array([float(x) for x in line.split(',')])) # Build the recommendation model using Alternating Least Squares diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index 3002a66a4fdb3..9cbd880897578 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -77,15 +77,17 @@ bins if the condition is not satisfied. **Categorical features** -For `$M$` categorical features, one could come up with `$2^M-1$` split candidates. However, for -binary classification, the number of split candidates can be reduced to `$M-1$` by ordering the +For `$M$` categorical feature values, one could come up with `$2^(M-1)-1$` split candidates. For +binary classification, we can reduce the number of split candidates to `$M-1$` by ordering the categorical feature values by the proportion of labels falling in one of the two classes (see Section 9.2.4 in [Elements of Statistical Machine Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for details). For example, for a binary classification problem with one categorical feature with three categories A, B and C with corresponding proportion of label 1 as 0.2, 0.6 and 0.4, the categorical features are ordered as A followed by C followed B or A, B, C. The two split candidates are A \| C, B -and A , B \| C where \| denotes the split. +and A , B \| C where \| denotes the split. A similar heuristic is used for multiclass classification +when `$2^(M-1)-1$` is greater than the number of bins -- the impurity for each categorical feature value +is used for ordering. ### Stopping rule @@ -122,7 +124,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.impurity.Gini // Load and parse the data file -val data = sc.textFile("mllib/data/sample_tree_data.csv") +val data = sc.textFile("data/mllib/sample_tree_data.csv") val parsedData = data.map { line => val parts = line.split(',').map(_.toDouble) LabeledPoint(parts(0), Vectors.dense(parts.tail)) @@ -161,7 +163,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.impurity.Variance // Load and parse the data file -val data = sc.textFile("mllib/data/sample_tree_data.csv") +val data = sc.textFile("data/mllib/sample_tree_data.csv") val parsedData = data.map { line => val parts = line.split(',').map(_.toDouble) LabeledPoint(parts(0), Vectors.dense(parts.tail)) diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 4dfbebbcd04b7..b4d22e0df5a85 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -187,7 +187,7 @@ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLUtils // Load training data in LIBSVM format. -val data = MLUtils.loadLibSVMFile(sc, "mllib/data/sample_libsvm_data.txt") +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") // Split data into training (60%) and test (40%). val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) @@ -259,7 +259,7 @@ def parsePoint(line): values = [float(x) for x in line.split(' ')] return LabeledPoint(values[0], values[1:]) -data = sc.textFile("mllib/data/sample_svm_data.txt") +data = sc.textFile("data/mllib/sample_svm_data.txt") parsedData = data.map(parsePoint) # Build the model @@ -309,7 +309,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.Vectors // Load and parse the data -val data = sc.textFile("mllib/data/ridge-data/lpsa.data") +val data = sc.textFile("data/mllib/ridge-data/lpsa.data") val parsedData = data.map { line => val parts = line.split(',') LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble))) @@ -356,7 +356,7 @@ def parsePoint(line): values = [float(x) for x in line.replace(',', ' ').split(' ')] return LabeledPoint(values[0], values[1:]) -data = sc.textFile("mllib/data/ridge-data/lpsa.data") +data = sc.textFile("data/mllib/ridge-data/lpsa.data") parsedData = data.map(parsePoint) # Build the model diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index 1d1d7dcf6ffcb..b1650c83c98b9 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -40,7 +40,7 @@ import org.apache.spark.mllib.classification.NaiveBayes import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint -val data = sc.textFile("mllib/data/sample_naive_bayes_data.txt") +val data = sc.textFile("data/mllib/sample_naive_bayes_data.txt") val parsedData = data.map { line => val parts = line.split(',') LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble))) diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md index ae9ede58e8e60..651958c7812f2 100644 --- a/docs/mllib-optimization.md +++ b/docs/mllib-optimization.md @@ -214,7 +214,7 @@ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLUtils import org.apache.spark.mllib.classification.LogisticRegressionModel -val data = MLUtils.loadLibSVMFile(sc, "mllib/data/sample_libsvm_data.txt") +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") val numFeatures = data.take(1)(0).features.size // Split data into training (60%) and test (40%). diff --git a/docs/programming-guide.md b/docs/programming-guide.md index b09d6347cd1b2..90c69713019f2 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -739,7 +739,7 @@ def doStuff(self, rdd): While most Spark operations work on RDDs containing any type of objects, a few special operations are only available on RDDs of key-value pairs. -The most common ones are distibuted "shuffle" operations, such as grouping or aggregating the elements +The most common ones are distributed "shuffle" operations, such as grouping or aggregating the elements by a key. In Scala, these operations are automatically available on RDDs containing @@ -773,7 +773,7 @@ documentation](http://docs.oracle.com/javase/7/docs/api/java/lang/Object.html#ha While most Spark operations work on RDDs containing any type of objects, a few special operations are only available on RDDs of key-value pairs. -The most common ones are distibuted "shuffle" operations, such as grouping or aggregating the elements +The most common ones are distributed "shuffle" operations, such as grouping or aggregating the elements by a key. In Java, key-value pairs are represented using the @@ -810,7 +810,7 @@ documentation](http://docs.oracle.com/javase/7/docs/api/java/lang/Object.html#ha While most Spark operations work on RDDs containing any type of objects, a few special operations are only available on RDDs of key-value pairs. -The most common ones are distibuted "shuffle" operations, such as grouping or aggregating the elements +The most common ones are distributed "shuffle" operations, such as grouping or aggregating the elements by a key. In Python, these operations work on RDDs containing built-in Python tuples such as `(1, 2)`. diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index f5c0f7cef83d2..ad8b6c0e51a78 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -156,6 +156,20 @@ SPARK_MASTER_OPTS supports the following system properties: + + + + + + + + + + diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 522c83884ef42..38728534a46e0 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -474,7 +474,7 @@ anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) Spark SQL also supports reading and writing data stored in [Apache Hive](http://hive.apache.org/). However, since Hive has a large number of dependencies, it is not included in the default Spark assembly. -In order to use Hive you must first run '`SPARK_HIVE=true sbt/sbt assembly/assembly`' (or use `-Phive` for maven). +In order to use Hive you must first run '`sbt/sbt -Phive assembly/assembly`' (or use `-Phive` for maven). This command builds a new assembly jar that includes Hive. Note that this Hive assembly jar must also be present on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries (SerDes) in order to acccess data stored in Hive. diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 44775ea479ece..02cfe4ec39c7d 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -240,7 +240,10 @@ def get_spark_ami(opts): "r3.xlarge": "hvm", "r3.2xlarge": "hvm", "r3.4xlarge": "hvm", - "r3.8xlarge": "hvm" + "r3.8xlarge": "hvm", + "t2.micro": "hvm", + "t2.small": "hvm", + "t2.medium": "hvm" } if opts.instance_type in instance_types: instance_type = instance_types[opts.instance_type] diff --git a/examples/pom.xml b/examples/pom.xml index 4f6d7fdb87d47..bd1c387c2eb91 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -27,6 +27,9 @@ org.apache.sparkspark-examples_2.10 + + examples + jarSpark Project Exampleshttp://spark.apache.org/ diff --git a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala index 4893b017ed819..822673347bdce 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala @@ -31,12 +31,12 @@ object HBaseTest { val conf = HBaseConfiguration.create() // Other options for configuring scan behavior are available. More information available at // http://hbase.apache.org/apidocs/org/apache/hadoop/hbase/mapreduce/TableInputFormat.html - conf.set(TableInputFormat.INPUT_TABLE, args(1)) + conf.set(TableInputFormat.INPUT_TABLE, args(0)) // Initialize hBase table if necessary val admin = new HBaseAdmin(conf) - if(!admin.isTableAvailable(args(1))) { - val tableDesc = new HTableDescriptor(args(1)) + if (!admin.isTableAvailable(args(0))) { + val tableDesc = new HTableDescriptor(args(0)) admin.createTable(tableDesc) } diff --git a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala index 331de3ad1ef53..ed2b38e2ca6f8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala @@ -19,16 +19,22 @@ package org.apache.spark.examples import org.apache.spark._ + object HdfsTest { + + /** Usage: HdfsTest [file] */ def main(args: Array[String]) { + if (args.length < 1) { + System.err.println("Usage: HdfsTest ") + System.exit(1) + } val sparkConf = new SparkConf().setAppName("HdfsTest") val sc = new SparkContext(sparkConf) - val file = sc.textFile(args(1)) + val file = sc.textFile(args(0)) val mapped = file.map(s => s.length).cache() for (iter <- 1 to 10) { val start = System.currentTimeMillis() for (x <- mapped) { x + 2 } - // println("Processing: " + x) val end = System.currentTimeMillis() println("Iteration " + iter + " took " + (end-start) + " ms") } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala index 4d28e0aad6597..79cfedf332436 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala @@ -17,8 +17,6 @@ package org.apache.spark.examples -import java.util.Random - import breeze.linalg.{Vector, DenseVector, squaredDistance} import org.apache.spark.{SparkConf, SparkContext} @@ -28,15 +26,12 @@ import org.apache.spark.SparkContext._ * K-means clustering. */ object SparkKMeans { - val R = 1000 // Scaling factor - val rand = new Random(42) def parseVector(line: String): Vector[Double] = { DenseVector(line.split(' ').map(_.toDouble)) } def closestPoint(p: Vector[Double], centers: Array[Vector[Double]]): Int = { - var index = 0 var bestIndex = 0 var closest = Double.PositiveInfinity diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index 40b36c779afd6..4c7e006da0618 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -31,8 +31,12 @@ import org.apache.spark.{SparkConf, SparkContext} */ object SparkPageRank { def main(args: Array[String]) { + if (args.length < 1) { + System.err.println("Usage: SparkPageRank ") + System.exit(1) + } val sparkConf = new SparkConf().setAppName("PageRank") - var iters = args(1).toInt + val iters = if (args.length > 0) args(1).toInt else 10 val ctx = new SparkContext(sparkConf) val lines = ctx.textFile(args(0), 1) val links = lines.map{ s => diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index b3cc361154198..43f13fe24f0d0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -49,6 +49,7 @@ object DecisionTreeRunner { case class Params( input: String = null, algo: Algo = Classification, + numClassesForClassification: Int = 2, maxDepth: Int = 5, impurity: ImpurityType = Gini, maxBins: Int = 100) @@ -68,6 +69,10 @@ object DecisionTreeRunner { opt[Int]("maxDepth") .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") .action((x, c) => c.copy(maxDepth = x)) + opt[Int]("numClassesForClassification") + .text(s"number of classes for classification, " + + s"default: ${defaultParams.numClassesForClassification}") + .action((x, c) => c.copy(numClassesForClassification = x)) opt[Int]("maxBins") .text(s"max number of bins, default: ${defaultParams.maxBins}") .action((x, c) => c.copy(maxBins = x)) @@ -118,7 +123,13 @@ object DecisionTreeRunner { case Variance => impurity.Variance } - val strategy = new Strategy(params.algo, impurityCalculator, params.maxDepth, params.maxBins) + val strategy + = new Strategy( + algo = params.algo, + impurity = impurityCalculator, + maxDepth = params.maxDepth, + maxBins = params.maxBins, + numClassesForClassification = params.numClassesForClassification) val model = DecisionTree.train(training, strategy) if (params.algo == Classification) { @@ -139,12 +150,8 @@ object DecisionTreeRunner { */ private def accuracyScore( model: DecisionTreeModel, - data: RDD[LabeledPoint], - threshold: Double = 0.5): Double = { - def predictedValue(features: Vector): Double = { - if (model.predict(features) < threshold) 0.0 else 1.0 - } - val correctCount = data.filter(y => predictedValue(y.features) == y.label).count() + data: RDD[LabeledPoint]): Double = { + val correctCount = data.filter(y => model.predict(y.features) == y.label).count() val count = data.count() correctCount.toDouble / count } diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index b262fabbe0e0d..66a23fac39999 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -42,7 +42,7 @@ object HiveFromSpark { hql("SELECT * FROM src").collect.foreach(println) // Aggregation queries are also supported. - val count = hql("SELECT COUNT(*) FROM src").collect().head.getInt(0) + val count = hql("SELECT COUNT(*) FROM src").collect().head.getLong(0) println(s"COUNT(*): $count") // The results of SQL queries are themselves RDDs and support all normal RDD functions. The diff --git a/external/flume/pom.xml b/external/flume/pom.xml index c1f581967777b..61a6aff543aed 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -27,6 +27,9 @@ org.apache.spark spark-streaming-flume_2.10 + + streaming-flume + jar Spark Project External Flume http://spark.apache.org/ diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index ed35e34ad45ab..56d2886b26878 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -20,6 +20,7 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress import java.io.{ObjectInput, ObjectOutput, Externalizable} import java.nio.ByteBuffer +import java.util.concurrent.Executors import scala.collection.JavaConversions._ import scala.reflect.ClassTag @@ -29,24 +30,32 @@ import org.apache.flume.source.avro.AvroFlumeEvent import org.apache.flume.source.avro.Status import org.apache.avro.ipc.specific.SpecificResponder import org.apache.avro.ipc.NettyServer - +import org.apache.spark.Logging import org.apache.spark.util.Utils import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.dstream._ -import org.apache.spark.Logging +import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.receiver.Receiver +import org.jboss.netty.channel.ChannelPipelineFactory +import org.jboss.netty.channel.Channels +import org.jboss.netty.channel.ChannelPipeline +import org.jboss.netty.channel.ChannelFactory +import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory +import org.jboss.netty.handler.codec.compression._ +import org.jboss.netty.handler.execution.ExecutionHandler + private[streaming] class FlumeInputDStream[T: ClassTag]( @transient ssc_ : StreamingContext, host: String, port: Int, - storageLevel: StorageLevel + storageLevel: StorageLevel, + enableDecompression: Boolean ) extends ReceiverInputDStream[SparkFlumeEvent](ssc_) { override def getReceiver(): Receiver[SparkFlumeEvent] = { - new FlumeReceiver(host, port, storageLevel) + new FlumeReceiver(host, port, storageLevel, enableDecompression) } } @@ -134,22 +143,71 @@ private[streaming] class FlumeReceiver( host: String, port: Int, - storageLevel: StorageLevel + storageLevel: StorageLevel, + enableDecompression: Boolean ) extends Receiver[SparkFlumeEvent](storageLevel) with Logging { lazy val responder = new SpecificResponder( classOf[AvroSourceProtocol], new FlumeEventServer(this)) - lazy val server = new NettyServer(responder, new InetSocketAddress(host, port)) + var server: NettyServer = null + + private def initServer() = { + if (enableDecompression) { + val channelFactory = new NioServerSocketChannelFactory(Executors.newCachedThreadPool(), + Executors.newCachedThreadPool()) + val channelPipelineFactory = new CompressionChannelPipelineFactory() + + new NettyServer( + responder, + new InetSocketAddress(host, port), + channelFactory, + channelPipelineFactory, + null) + } else { + new NettyServer(responder, new InetSocketAddress(host, port)) + } + } def onStart() { - server.start() + synchronized { + if (server == null) { + server = initServer() + server.start() + } else { + logWarning("Flume receiver being asked to start more then once with out close") + } + } logInfo("Flume receiver started") } def onStop() { - server.close() + synchronized { + if (server != null) { + server.close() + server = null + } + } logInfo("Flume receiver stopped") } override def preferredLocation = Some(host) + + /** A Netty Pipeline factory that will decompress incoming data from + * and the Netty client and compress data going back to the client. + * + * The compression on the return is required because Flume requires + * a successful response to indicate it can remove the event/batch + * from the configured channel + */ + private[streaming] + class CompressionChannelPipelineFactory extends ChannelPipelineFactory { + + def getPipeline() = { + val pipeline = Channels.pipeline() + val encoder = new ZlibEncoder(6) + pipeline.addFirst("deflater", encoder) + pipeline.addFirst("inflater", new ZlibDecoder()) + pipeline + } +} } diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index 499f3560ef768..716db9fa76031 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -36,7 +36,27 @@ object FlumeUtils { port: Int, storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 ): ReceiverInputDStream[SparkFlumeEvent] = { - val inputStream = new FlumeInputDStream[SparkFlumeEvent](ssc, hostname, port, storageLevel) + createStream(ssc, hostname, port, storageLevel, false) + } + + /** + * Create a input stream from a Flume source. + * @param ssc StreamingContext object + * @param hostname Hostname of the slave machine to which the flume data will be sent + * @param port Port of the slave machine to which the flume data will be sent + * @param storageLevel Storage level to use for storing the received objects + * @param enableDecompression should netty server decompress input stream + */ + def createStream ( + ssc: StreamingContext, + hostname: String, + port: Int, + storageLevel: StorageLevel, + enableDecompression: Boolean + ): ReceiverInputDStream[SparkFlumeEvent] = { + val inputStream = new FlumeInputDStream[SparkFlumeEvent]( + ssc, hostname, port, storageLevel, enableDecompression) + inputStream } @@ -66,6 +86,23 @@ object FlumeUtils { port: Int, storageLevel: StorageLevel ): JavaReceiverInputDStream[SparkFlumeEvent] = { - createStream(jssc.ssc, hostname, port, storageLevel) + createStream(jssc.ssc, hostname, port, storageLevel, false) + } + + /** + * Creates a input stream from a Flume source. + * @param hostname Hostname of the slave machine to which the flume data will be sent + * @param port Port of the slave machine to which the flume data will be sent + * @param storageLevel Storage level to use for storing the received objects + * @param enableDecompression should netty server decompress input stream + */ + def createStream( + jssc: JavaStreamingContext, + hostname: String, + port: Int, + storageLevel: StorageLevel, + enableDecompression: Boolean + ): JavaReceiverInputDStream[SparkFlumeEvent] = { + createStream(jssc.ssc, hostname, port, storageLevel, enableDecompression) } } diff --git a/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java b/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java index e0ad4f1015205..3b5e0c7746b2c 100644 --- a/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java +++ b/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java @@ -30,5 +30,7 @@ public void testFlumeStream() { JavaReceiverInputDStream test1 = FlumeUtils.createStream(ssc, "localhost", 12345); JavaReceiverInputDStream test2 = FlumeUtils.createStream(ssc, "localhost", 12345, StorageLevel.MEMORY_AND_DISK_SER_2()); + JavaReceiverInputDStream test3 = FlumeUtils.createStream(ssc, "localhost", 12345, + StorageLevel.MEMORY_AND_DISK_SER_2(), false); } } diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index dd287d0ef90a0..73dffef953309 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -33,15 +33,26 @@ import org.apache.spark.streaming.{TestOutputStream, StreamingContext, TestSuite import org.apache.spark.streaming.util.ManualClock import org.apache.spark.streaming.api.java.JavaReceiverInputDStream -class FlumeStreamSuite extends TestSuiteBase { +import org.jboss.netty.channel.ChannelPipeline +import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory +import org.jboss.netty.channel.socket.SocketChannel +import org.jboss.netty.handler.codec.compression._ - val testPort = 9999 +class FlumeStreamSuite extends TestSuiteBase { test("flume input stream") { + runFlumeStreamTest(false, 9998) + } + + test("flume input compressed stream") { + runFlumeStreamTest(true, 9997) + } + + def runFlumeStreamTest(enableDecompression: Boolean, testPort: Int) { // Set up the streaming context and input streams val ssc = new StreamingContext(conf, batchDuration) val flumeStream: JavaReceiverInputDStream[SparkFlumeEvent] = - FlumeUtils.createStream(ssc, "localhost", testPort, StorageLevel.MEMORY_AND_DISK) + FlumeUtils.createStream(ssc, "localhost", testPort, StorageLevel.MEMORY_AND_DISK, enableDecompression) val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] with SynchronizedBuffer[Seq[SparkFlumeEvent]] val outputStream = new TestOutputStream(flumeStream.receiverInputDStream, outputBuffer) @@ -52,8 +63,17 @@ class FlumeStreamSuite extends TestSuiteBase { val input = Seq(1, 2, 3, 4, 5) Thread.sleep(1000) val transceiver = new NettyTransceiver(new InetSocketAddress("localhost", testPort)) - val client = SpecificRequestor.getClient( - classOf[AvroSourceProtocol], transceiver) + var client: AvroSourceProtocol = null; + + if (enableDecompression) { + client = SpecificRequestor.getClient( + classOf[AvroSourceProtocol], + new NettyTransceiver(new InetSocketAddress("localhost", testPort), + new CompressionChannelFactory(6))); + } else { + client = SpecificRequestor.getClient( + classOf[AvroSourceProtocol], transceiver) + } for (i <- 0 until input.size) { val event = new AvroFlumeEvent @@ -64,6 +84,8 @@ class FlumeStreamSuite extends TestSuiteBase { clock.addToTime(batchDuration.milliseconds) } + Thread.sleep(1000) + val startTime = System.currentTimeMillis() while (outputBuffer.size < input.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { logInfo("output.size = " + outputBuffer.size + ", input.size = " + input.size) @@ -85,4 +107,13 @@ class FlumeStreamSuite extends TestSuiteBase { assert(outputBuffer(i).head.event.getHeaders.get("test") === "header") } } + + class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { + override def newChannel(pipeline:ChannelPipeline) : SocketChannel = { + var encoder : ZlibEncoder = new ZlibEncoder(compressionLevel); + pipeline.addFirst("deflater", encoder); + pipeline.addFirst("inflater", new ZlibDecoder()); + super.newChannel(pipeline); + } + } } diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index d014a7aad0fca..4762c50685a93 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -27,6 +27,9 @@ org.apache.spark spark-streaming-kafka_2.10 + + streaming-kafka + jar Spark Project External Kafka http://spark.apache.org/ diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 4980208cba3b0..32c530e600ce0 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -27,6 +27,9 @@ org.apache.spark spark-streaming-mqtt_2.10 + + streaming-mqtt + jar Spark Project External MQTT http://spark.apache.org/ diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 7073bd4404d9c..637adb0f00da0 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -27,6 +27,9 @@ org.apache.spark spark-streaming-twitter_2.10 + + streaming-twitter + jar Spark Project External Twitter http://spark.apache.org/ diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index cf306e0dca8bd..e4d758a04a4cd 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -27,6 +27,9 @@ org.apache.spark spark-streaming-zeromq_2.10 + + streaming-zeromq + jar Spark Project External ZeroMQ http://spark.apache.org/ diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 955ec1a8c3033..3eade411b38b7 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -28,7 +28,11 @@ java8-tests_2.10 pom Spark Project Java8 Tests POM - + + + java8-tests + + org.apache.spark diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index 22ea330b4374d..a5b162a0482e4 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -29,7 +29,11 @@ spark-ganglia-lgpl_2.10 jar Spark Ganglia Integration - + + + ganglia-lgpl + + org.apache.spark diff --git a/graphx/pom.xml b/graphx/pom.xml index 7d5d83e7f3bb9..7e3bcf29dcfbc 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -27,6 +27,9 @@ org.apache.spark spark-graphx_2.10 + + graphx + jar Spark Project GraphX http://spark.apache.org/ diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index 4db45c9af8fae..3507f358bfb40 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -107,14 +107,16 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab /** * Repartitions the edges in the graph according to `partitionStrategy`. * - * @param the partitioning strategy to use when partitioning the edges in the graph. + * @param partitionStrategy the partitioning strategy to use when partitioning the edges + * in the graph. */ def partitionBy(partitionStrategy: PartitionStrategy): Graph[VD, ED] /** * Repartitions the edges in the graph according to `partitionStrategy`. * - * @param the partitioning strategy to use when partitioning the edges in the graph. + * @param partitionStrategy the partitioning strategy to use when partitioning the edges + * in the graph. * @param numPartitions the number of edge partitions in the new graph. */ def partitionBy(partitionStrategy: PartitionStrategy, numPartitions: Int): Graph[VD, ED] diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index f1b6df9a3025e..4825d12fc27b3 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -182,8 +182,8 @@ class VertexRDD[@specialized VD: ClassTag]( /** * Left joins this RDD with another VertexRDD with the same index. This function will fail if * both VertexRDDs do not share the same index. The resulting vertex set contains an entry for - * each - * vertex in `this`. If `other` is missing any vertex in this VertexRDD, `f` is passed `None`. + * each vertex in `this`. + * If `other` is missing any vertex in this VertexRDD, `f` is passed `None`. * * @tparam VD2 the attribute type of the other VertexRDD * @tparam VD3 the attribute type of the resulting VertexRDD diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index 3827ac8d0fd6a..502b112d31c2e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -119,7 +119,7 @@ object RoutingTablePartition { */ private[graphx] class RoutingTablePartition( - private val routingTable: Array[(Array[VertexId], BitSet, BitSet)]) { + private val routingTable: Array[(Array[VertexId], BitSet, BitSet)]) extends Serializable { /** The maximum number of edge partitions this `RoutingTablePartition` is built to join with. */ val numEdgePartitions: Int = routingTable.size diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala index 34939b24440aa..5ad6390a56c4f 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala @@ -60,7 +60,8 @@ private[graphx] object VertexPartitionBase { * `VertexPartitionBaseOpsConstructor` typeclass (for example, * [[VertexPartition.VertexPartitionOpsConstructor]]). */ -private[graphx] abstract class VertexPartitionBase[@specialized(Long, Int, Double) VD: ClassTag] { +private[graphx] abstract class VertexPartitionBase[@specialized(Long, Int, Double) VD: ClassTag] + extends Serializable { def index: VertexIdToIndexMap def values: Array[VD] diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala index a4f769b294010..b40aa1b417a0f 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala @@ -35,7 +35,7 @@ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap private[graphx] abstract class VertexPartitionBaseOps [VD: ClassTag, Self[X] <: VertexPartitionBase[X] : VertexPartitionBaseOpsConstructor] (self: Self[VD]) - extends Logging { + extends Serializable with Logging { def withIndex(index: VertexIdToIndexMap): Self[VD] def withValues[VD2: ClassTag](values: Array[VD2]): Self[VD2] diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala index 28fd112f2b124..9d00f76327e4c 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala @@ -23,6 +23,7 @@ import scala.util.Random import org.scalatest.FunSuite import org.apache.spark.SparkConf +import org.apache.spark.serializer.JavaSerializer import org.apache.spark.serializer.KryoSerializer import org.apache.spark.graphx._ @@ -124,18 +125,21 @@ class EdgePartitionSuite extends FunSuite { assert(ep.numActives == Some(2)) } - test("Kryo serialization") { + test("serialization") { val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0)) val a: EdgePartition[Int, Int] = makeEdgePartition(aList) - val conf = new SparkConf() + val javaSer = new JavaSerializer(new SparkConf()) + val kryoSer = new KryoSerializer(new SparkConf() .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator") - val s = new KryoSerializer(conf).newInstance() - val aSer: EdgePartition[Int, Int] = s.deserialize(s.serialize(a)) - assert(aSer.srcIds.toList === a.srcIds.toList) - assert(aSer.dstIds.toList === a.dstIds.toList) - assert(aSer.data.toList === a.data.toList) - assert(aSer.index != null) - assert(aSer.vertices.iterator.toSet === a.vertices.iterator.toSet) + .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")) + + for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) { + val aSer: EdgePartition[Int, Int] = s.deserialize(s.serialize(a)) + assert(aSer.srcIds.toList === a.srcIds.toList) + assert(aSer.dstIds.toList === a.dstIds.toList) + assert(aSer.data.toList === a.data.toList) + assert(aSer.index != null) + assert(aSer.vertices.iterator.toSet === a.vertices.iterator.toSet) + } } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala index 8bf1384d514c1..f9e771a900013 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala @@ -17,9 +17,14 @@ package org.apache.spark.graphx.impl -import org.apache.spark.graphx._ import org.scalatest.FunSuite +import org.apache.spark.SparkConf +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.serializer.KryoSerializer + +import org.apache.spark.graphx._ + class VertexPartitionSuite extends FunSuite { test("isDefined, filter") { @@ -116,4 +121,17 @@ class VertexPartitionSuite extends FunSuite { assert(vp3.index.getPos(2) === -1) } + test("serialization") { + val verts = Set((0L, 1), (1L, 1), (2L, 1)) + val vp = VertexPartition(verts.iterator) + val javaSer = new JavaSerializer(new SparkConf()) + val kryoSer = new KryoSerializer(new SparkConf() + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")) + + for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) { + val vpSer: VertexPartition[Int] = s.deserialize(s.serialize(vp)) + assert(vpSer.iterator.toSet === verts) + } + } } diff --git a/make-distribution.sh b/make-distribution.sh index 94b473bf91cd3..c08093f46b61f 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -23,21 +23,6 @@ # The distribution contains fat (assembly) jars that include the Scala library, # so it is completely self contained. # It does not contain source or *.class files. -# -# Optional Arguments -# --tgz: Additionally creates spark-$VERSION-bin.tar.gz -# --hadoop VERSION: Builds against specified version of Hadoop. -# --with-yarn: Enables support for Hadoop YARN. -# --with-hive: Enable support for reading Hive tables. -# --name: A moniker for the release target. Defaults to the Hadoop verison. -# -# Recommended deploy/testing procedure (standalone mode): -# 1) Rsync / deploy the dist/ dir to one host -# 2) cd to deploy dir; ./sbin/start-master.sh -# 3) Verify master is up by visiting web page, ie http://master-ip:8080. Note the spark:// URL. -# 4) ./sbin/start-slave.sh 1 <> -# 5) ./bin/spark-shell --master spark://my-master-ip:7077 -# set -o pipefail set -e @@ -46,26 +31,35 @@ set -e FWDIR="$(cd `dirname $0`; pwd)" DISTDIR="$FWDIR/dist" -# Initialize defaults -SPARK_HADOOP_VERSION=1.0.4 -SPARK_YARN=false -SPARK_HIVE=false SPARK_TACHYON=false MAKE_TGZ=false NAME=none +function exit_with_usage { + echo "make-distribution.sh - tool for making binary distributions of Spark" + echo "" + echo "usage:" + echo "./make-distribution.sh [--name] [--tgz] [--with-tachyon] " + echo "See Spark's \"Building with Maven\" doc for correct Maven options." + echo "" + exit 1 +} + # Parse arguments while (( "$#" )); do case $1 in --hadoop) - SPARK_HADOOP_VERSION="$2" - shift + echo "Error: '--hadoop' is no longer supported:" + echo "Error: use Maven options -Phadoop.version and -Pyarn.version" + exit_with_usage ;; --with-yarn) - SPARK_YARN=true + echo "Error: '--with-yarn' is no longer supported, use Maven option -Pyarn" + exit_with_usage ;; --with-hive) - SPARK_HIVE=true + echo "Error: '--with-hive' is no longer supported, use Maven option -Phive" + exit_with_usage ;; --skip-java-test) SKIP_JAVA_TEST=true @@ -80,6 +74,12 @@ while (( "$#" )); do NAME="$2" shift ;; + --help) + exit_with_usage + ;; + *) + break + ;; esac shift done @@ -143,14 +143,6 @@ else echo "Making distribution for Spark $VERSION in $DISTDIR..." fi -echo "Hadoop version set to $SPARK_HADOOP_VERSION" -echo "Release name set to $NAME" -if [ "$SPARK_YARN" == "true" ]; then - echo "YARN enabled" -else - echo "YARN disabled" -fi - if [ "$SPARK_TACHYON" == "true" ]; then echo "Tachyon Enabled" else @@ -162,33 +154,12 @@ cd $FWDIR export MAVEN_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" -BUILD_COMMAND="mvn clean package" - -# Use special profiles for hadoop versions 0.23.x, 2.2.x, 2.3.x, 2.4.x -if [[ "$SPARK_HADOOP_VERSION" =~ ^0\.23\. ]]; then BUILD_COMMAND="$BUILD_COMMAND -Phadoop-0.23"; fi -if [[ "$SPARK_HADOOP_VERSION" =~ ^2\.2\. ]]; then BUILD_COMMAND="$BUILD_COMMAND -Phadoop-2.2"; fi -if [[ "$SPARK_HADOOP_VERSION" =~ ^2\.3\. ]]; then BUILD_COMMAND="$BUILD_COMMAND -Phadoop-2.3"; fi -if [[ "$SPARK_HADOOP_VERSION" =~ ^2\.4\. ]]; then BUILD_COMMAND="$BUILD_COMMAND -Phadoop-2.4"; fi -if [[ "$SPARK_HIVE" == "true" ]]; then BUILD_COMMAND="$BUILD_COMMAND -Phive"; fi -if [[ "$SPARK_YARN" == "true" ]]; then - # For hadoop versions 0.23.x to 2.1.x, use the yarn-alpha profile - if [[ "$SPARK_HADOOP_VERSION" =~ ^0\.2[3-9]\. ]] || - [[ "$SPARK_HADOOP_VERSION" =~ ^0\.[3-9][0-9]\. ]] || - [[ "$SPARK_HADOOP_VERSION" =~ ^1\.[0-9]\. ]] || - [[ "$SPARK_HADOOP_VERSION" =~ ^2\.[0-1]\. ]]; then - BUILD_COMMAND="$BUILD_COMMAND -Pyarn-alpha" - # For hadoop versions 2.2+, use the yarn profile - elif [[ "$SPARK_HADOOP_VERSION" =~ ^2.[2-9]. ]]; then - BUILD_COMMAND="$BUILD_COMMAND -Pyarn" - fi - BUILD_COMMAND="$BUILD_COMMAND -Dyarn.version=$SPARK_HADOOP_VERSION" -fi -BUILD_COMMAND="$BUILD_COMMAND -Dhadoop.version=$SPARK_HADOOP_VERSION" -BUILD_COMMAND="$BUILD_COMMAND -DskipTests" +BUILD_COMMAND="mvn clean package -DskipTests $@" # Actually build the jar echo -e "\nBuilding with..." echo -e "\$ $BUILD_COMMAND\n" + ${BUILD_COMMAND} # Make directories diff --git a/mllib/pom.xml b/mllib/pom.xml index b622f96dd7901..92b07e2357db1 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -27,6 +27,9 @@ org.apache.spark spark-mllib_2.10 + + mllib + jar Spark Project ML Library http://spark.apache.org/ @@ -75,6 +78,19 @@ test + + + netlib-lgpl + + + com.github.fommil.netlib + all + 1.1.2 + pom + + + + target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala index 2e3a4ce783de7..f0722d7c14a46 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala @@ -59,7 +59,13 @@ private[mllib] object LocalKMeans extends Logging { cumulativeScore += weights(j) * KMeans.pointCost(curCenters, points(j)) j += 1 } - centers(i) = points(j-1).toDense + if (j == 0) { + logWarning("kMeansPlusPlus initialization ran out of distinct points for centers." + + s" Using duplicate point for center k = $i.") + centers(i) = points(0).toDense + } else { + centers(i) = points(j - 1).toDense + } } // Run up to maxIterations iterations of Lloyd's algorithm diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala new file mode 100644 index 0000000000000..666362ae6739a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation + +import scala.collection.Map + +import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.{Matrices, Matrix} +import org.apache.spark.rdd.RDD + +/** + * ::Experimental:: + * Evaluator for multiclass classification. + * + * @param predictionAndLabels an RDD of (prediction, label) pairs. + */ +@Experimental +class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) { + + private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue() + private lazy val labelCount: Long = labelCountByClass.values.sum + private lazy val tpByClass: Map[Double, Int] = predictionAndLabels + .map { case (prediction, label) => + (label, if (label == prediction) 1 else 0) + }.reduceByKey(_ + _) + .collectAsMap() + private lazy val fpByClass: Map[Double, Int] = predictionAndLabels + .map { case (prediction, label) => + (prediction, if (prediction != label) 1 else 0) + }.reduceByKey(_ + _) + .collectAsMap() + private lazy val confusions = predictionAndLabels + .map { case (prediction, label) => + ((label, prediction), 1) + }.reduceByKey(_ + _) + .collectAsMap() + + /** + * Returns confusion matrix: + * predicted classes are in columns, + * they are ordered by class label ascending, + * as in "labels" + */ + def confusionMatrix: Matrix = { + val n = labels.size + val values = Array.ofDim[Double](n * n) + var i = 0 + while (i < n) { + var j = 0 + while (j < n) { + values(i + j * n) = confusions.getOrElse((labels(i), labels(j)), 0).toDouble + j += 1 + } + i += 1 + } + Matrices.dense(n, n, values) + } + + /** + * Returns true positive rate for a given label (category) + * @param label the label. + */ + def truePositiveRate(label: Double): Double = recall(label) + + /** + * Returns false positive rate for a given label (category) + * @param label the label. + */ + def falsePositiveRate(label: Double): Double = { + val fp = fpByClass.getOrElse(label, 0) + fp.toDouble / (labelCount - labelCountByClass(label)) + } + + /** + * Returns precision for a given label (category) + * @param label the label. + */ + def precision(label: Double): Double = { + val tp = tpByClass(label) + val fp = fpByClass.getOrElse(label, 0) + if (tp + fp == 0) 0 else tp.toDouble / (tp + fp) + } + + /** + * Returns recall for a given label (category) + * @param label the label. + */ + def recall(label: Double): Double = tpByClass(label).toDouble / labelCountByClass(label) + + /** + * Returns f-measure for a given label (category) + * @param label the label. + * @param beta the beta parameter. + */ + def fMeasure(label: Double, beta: Double): Double = { + val p = precision(label) + val r = recall(label) + val betaSqrd = beta * beta + if (p + r == 0) 0 else (1 + betaSqrd) * p * r / (betaSqrd * p + r) + } + + /** + * Returns f1-measure for a given label (category) + * @param label the label. + */ + def fMeasure(label: Double): Double = fMeasure(label, 1.0) + + /** + * Returns precision + */ + lazy val precision: Double = tpByClass.values.sum.toDouble / labelCount + + /** + * Returns recall + * (equals to precision for multiclass classifier + * because sum of all false positives is equal to sum + * of all false negatives) + */ + lazy val recall: Double = precision + + /** + * Returns f-measure + * (equals to precision and recall because precision equals recall) + */ + lazy val fMeasure: Double = precision + + /** + * Returns weighted true positive rate + * (equals to precision, recall and f-measure) + */ + lazy val weightedTruePositiveRate: Double = weightedRecall + + /** + * Returns weighted false positive rate + */ + lazy val weightedFalsePositiveRate: Double = labelCountByClass.map { case (category, count) => + falsePositiveRate(category) * count.toDouble / labelCount + }.sum + + /** + * Returns weighted averaged recall + * (equals to precision, recall and f-measure) + */ + lazy val weightedRecall: Double = labelCountByClass.map { case (category, count) => + recall(category) * count.toDouble / labelCount + }.sum + + /** + * Returns weighted averaged precision + */ + lazy val weightedPrecision: Double = labelCountByClass.map { case (category, count) => + precision(category) * count.toDouble / labelCount + }.sum + + /** + * Returns weighted averaged f-measure + * @param beta the beta parameter. + */ + def weightedFMeasure(beta: Double): Double = labelCountByClass.map { case (category, count) => + fMeasure(category, beta) * count.toDouble / labelCount + }.sum + + /** + * Returns weighted averaged f1-measure + */ + lazy val weightedFMeasure: Double = labelCountByClass.map { case (category, count) => + fMeasure(category, 1.0) * count.toDouble / labelCount + }.sum + + /** + * Returns the sequence of labels in ascending order + */ + lazy val labels: Array[Double] = tpByClass.keys.toArray.sorted +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala new file mode 100644 index 0000000000000..3515461b52493 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg + +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV} +import com.github.fommil.netlib.ARPACK +import org.netlib.util.{intW, doubleW} + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Compute eigen-decomposition. + */ +@Experimental +private[mllib] object EigenValueDecomposition { + /** + * Compute the leading k eigenvalues and eigenvectors on a symmetric square matrix using ARPACK. + * The caller needs to ensure that the input matrix is real symmetric. This function requires + * memory for `n*(4*k+4)` doubles. + * + * @param mul a function that multiplies the symmetric matrix with a DenseVector. + * @param n dimension of the square matrix (maximum Int.MaxValue). + * @param k number of leading eigenvalues required, 0 < k < n. + * @param tol tolerance of the eigs computation. + * @param maxIterations the maximum number of Arnoldi update iterations. + * @return a dense vector of eigenvalues in descending order and a dense matrix of eigenvectors + * (columns of the matrix). + * @note The number of computed eigenvalues might be smaller than k when some Ritz values do not + * satisfy the convergence criterion specified by tol (see ARPACK Users Guide, Chapter 4.6 + * for more details). The maximum number of Arnoldi update iterations is set to 300 in this + * function. + */ + private[mllib] def symmetricEigs( + mul: BDV[Double] => BDV[Double], + n: Int, + k: Int, + tol: Double, + maxIterations: Int): (BDV[Double], BDM[Double]) = { + // TODO: remove this function and use eigs in breeze when switching breeze version + require(n > k, s"Number of required eigenvalues $k must be smaller than matrix dimension $n") + + val arpack = ARPACK.getInstance() + + // tolerance used in stopping criterion + val tolW = new doubleW(tol) + // number of desired eigenvalues, 0 < nev < n + val nev = new intW(k) + // nev Lanczos vectors are generated in the first iteration + // ncv-nev Lanczos vectors are generated in each subsequent iteration + // ncv must be smaller than n + val ncv = math.min(2 * k, n) + + // "I" for standard eigenvalue problem, "G" for generalized eigenvalue problem + val bmat = "I" + // "LM" : compute the NEV largest (in magnitude) eigenvalues + val which = "LM" + + var iparam = new Array[Int](11) + // use exact shift in each iteration + iparam(0) = 1 + // maximum number of Arnoldi update iterations, or the actual number of iterations on output + iparam(2) = maxIterations + // Mode 1: A*x = lambda*x, A symmetric + iparam(6) = 1 + + var ido = new intW(0) + var info = new intW(0) + var resid = new Array[Double](n) + var v = new Array[Double](n * ncv) + var workd = new Array[Double](n * 3) + var workl = new Array[Double](ncv * (ncv + 8)) + var ipntr = new Array[Int](11) + + // call ARPACK's reverse communication, first iteration with ido = 0 + arpack.dsaupd(ido, bmat, n, which, nev.`val`, tolW, resid, ncv, v, n, iparam, ipntr, workd, + workl, workl.length, info) + + val w = BDV(workd) + + // ido = 99 : done flag in reverse communication + while (ido.`val` != 99) { + if (ido.`val` != -1 && ido.`val` != 1) { + throw new IllegalStateException("ARPACK returns ido = " + ido.`val` + + " This flag is not compatible with Mode 1: A*x = lambda*x, A symmetric.") + } + // multiply working vector with the matrix + val inputOffset = ipntr(0) - 1 + val outputOffset = ipntr(1) - 1 + val x = w.slice(inputOffset, inputOffset + n) + val y = w.slice(outputOffset, outputOffset + n) + y := mul(x) + // call ARPACK's reverse communication + arpack.dsaupd(ido, bmat, n, which, nev.`val`, tolW, resid, ncv, v, n, iparam, ipntr, + workd, workl, workl.length, info) + } + + if (info.`val` != 0) { + info.`val` match { + case 1 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` + + " Maximum number of iterations taken. (Refer ARPACK user guide for details)") + case 2 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` + + " No shifts could be applied. Try to increase NCV. " + + "(Refer ARPACK user guide for details)") + case _ => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` + + " Please refer ARPACK user guide for error message.") + } + } + + val d = new Array[Double](nev.`val`) + val select = new Array[Boolean](ncv) + // copy the Ritz vectors + val z = java.util.Arrays.copyOfRange(v, 0, nev.`val` * n) + + // call ARPACK's post-processing for eigenvectors + arpack.dseupd(true, "A", select, d, z, n, 0.0, bmat, n, which, nev, tol, resid, ncv, v, n, + iparam, ipntr, workd, workl, workl.length, info) + + // number of computed eigenvalues, might be smaller than k + val computed = iparam(4) + + val eigenPairs = java.util.Arrays.copyOfRange(d, 0, computed).zipWithIndex.map { r => + (r._1, java.util.Arrays.copyOfRange(z, r._2 * n, r._2 * n + n)) + } + + // sort the eigen-pairs in descending order + val sortedEigenPairs = eigenPairs.sortBy(- _._1) + + // copy eigenvectors in descending order of eigenvalues + val sortedU = BDM.zeros[Double](n, computed) + sortedEigenPairs.zipWithIndex.foreach { r => + val b = r._2 * n + var i = 0 + while (i < n) { + sortedU.data(b + i) = r._1._2(i) + i += 1 + } + } + + (BDV[Double](sortedEigenPairs.map(_._1)), sortedU) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index c818a0b9c3e43..77b3e8c714997 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -62,7 +62,7 @@ trait Vector extends Serializable { * Gets the value of the ith element. * @param i index */ - private[mllib] def apply(i: Int): Double = toBreeze(i) + def apply(i: Int): Double = toBreeze(i) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 695e03b736baf..f4c403bc7861c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -17,9 +17,10 @@ package org.apache.spark.mllib.linalg.distributed -import java.util +import java.util.Arrays -import breeze.linalg.{Vector => BV, DenseMatrix => BDM, DenseVector => BDV, svd => brzSvd} +import breeze.linalg.{Vector => BV, DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV} +import breeze.linalg.{svd => brzSvd, axpy => brzAxpy} import breeze.numerics.{sqrt => brzSqrt} import com.github.fommil.netlib.BLAS.{getInstance => blas} @@ -27,138 +28,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg._ import org.apache.spark.rdd.RDD import org.apache.spark.Logging -import org.apache.spark.mllib.stat.MultivariateStatisticalSummary - -/** - * Column statistics aggregator implementing - * [[org.apache.spark.mllib.stat.MultivariateStatisticalSummary]] - * together with add() and merge() function. - * A numerically stable algorithm is implemented to compute sample mean and variance: - *[[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]]. - * Zero elements (including explicit zero values) are skipped when calling add() and merge(), - * to have time complexity O(nnz) instead of O(n) for each column. - */ -private class ColumnStatisticsAggregator(private val n: Int) - extends MultivariateStatisticalSummary with Serializable { - - private val currMean: BDV[Double] = BDV.zeros[Double](n) - private val currM2n: BDV[Double] = BDV.zeros[Double](n) - private var totalCnt = 0.0 - private val nnz: BDV[Double] = BDV.zeros[Double](n) - private val currMax: BDV[Double] = BDV.fill(n)(Double.MinValue) - private val currMin: BDV[Double] = BDV.fill(n)(Double.MaxValue) - - override def mean: Vector = { - val realMean = BDV.zeros[Double](n) - var i = 0 - while (i < n) { - realMean(i) = currMean(i) * nnz(i) / totalCnt - i += 1 - } - Vectors.fromBreeze(realMean) - } - - override def variance: Vector = { - val realVariance = BDV.zeros[Double](n) - - val denominator = totalCnt - 1.0 - - // Sample variance is computed, if the denominator is less than 0, the variance is just 0. - if (denominator > 0.0) { - val deltaMean = currMean - var i = 0 - while (i < currM2n.size) { - realVariance(i) = - currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt - realVariance(i) /= denominator - i += 1 - } - } - - Vectors.fromBreeze(realVariance) - } - - override def count: Long = totalCnt.toLong - - override def numNonzeros: Vector = Vectors.fromBreeze(nnz) - - override def max: Vector = { - var i = 0 - while (i < n) { - if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0 - i += 1 - } - Vectors.fromBreeze(currMax) - } - - override def min: Vector = { - var i = 0 - while (i < n) { - if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0 - i += 1 - } - Vectors.fromBreeze(currMin) - } - - /** - * Aggregates a row. - */ - def add(currData: BV[Double]): this.type = { - currData.activeIterator.foreach { - case (_, 0.0) => // Skip explicit zero elements. - case (i, value) => - if (currMax(i) < value) { - currMax(i) = value - } - if (currMin(i) > value) { - currMin(i) = value - } - - val tmpPrevMean = currMean(i) - currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0) - currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean) - - nnz(i) += 1.0 - } - - totalCnt += 1.0 - this - } - - /** - * Merges another aggregator. - */ - def merge(other: ColumnStatisticsAggregator): this.type = { - require(n == other.n, s"Dimensions mismatch. Expecting $n but got ${other.n}.") - - totalCnt += other.totalCnt - val deltaMean = currMean - other.currMean - - var i = 0 - while (i < n) { - // merge mean together - if (other.currMean(i) != 0.0) { - currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) / - (nnz(i) + other.nnz(i)) - } - // merge m2n together - if (nnz(i) + other.nnz(i) != 0.0) { - currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) / - (nnz(i) + other.nnz(i)) - } - if (currMax(i) < other.currMax(i)) { - currMax(i) = other.currMax(i) - } - if (currMin(i) > other.currMin(i)) { - currMin(i) = other.currMin(i) - } - i += 1 - } - - nnz += other.nnz - this - } -} +import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} /** * :: Experimental :: @@ -200,6 +70,32 @@ class RowMatrix( nRows } + /** + * Multiplies the Gramian matrix `A^T A` by a dense vector on the right without computing `A^T A`. + * + * @param v a dense vector whose length must match the number of columns of this matrix + * @return a dense vector representing the product + */ + private[mllib] def multiplyGramianMatrixBy(v: BDV[Double]): BDV[Double] = { + val n = numCols().toInt + val vbr = rows.context.broadcast(v) + rows.aggregate(BDV.zeros[Double](n))( + seqOp = (U, r) => { + val rBrz = r.toBreeze + val a = rBrz.dot(vbr.value) + rBrz match { + // use specialized axpy for better performance + case _: BDV[_] => brzAxpy(a, rBrz.asInstanceOf[BDV[Double]], U) + case _: BSV[_] => brzAxpy(a, rBrz.asInstanceOf[BSV[Double]], U) + case _ => throw new UnsupportedOperationException( + s"Do not support vector operation from type ${rBrz.getClass.getName}.") + } + U + }, + combOp = (U1, U2) => U1 += U2 + ) + } + /** * Computes the Gramian matrix `A^T A`. */ @@ -220,50 +116,135 @@ class RowMatrix( } /** - * Computes the singular value decomposition of this matrix. - * Denote this matrix by A (m x n), this will compute matrices U, S, V such that A = U * S * V'. + * Computes singular value decomposition of this matrix. Denote this matrix by A (m x n). This + * will compute matrices U, S, V such that A ~= U * S * V', where S contains the leading k + * singular values, U and V contain the corresponding singular vectors. * - * There is no restriction on m, but we require `n^2` doubles to fit in memory. - * Further, n should be less than m. - - * The decomposition is computed by first computing A'A = V S^2 V', - * computing svd locally on that (since n x n is small), from which we recover S and V. - * Then we compute U via easy matrix multiplication as U = A * (V * S^-1). - * Note that this approach requires `O(n^3)` time on the master node. + * At most k largest non-zero singular values and associated vectors are returned. If there are k + * such values, then the dimensions of the return will be: + * - U is a RowMatrix of size m x k that satisfies U' * U = eye(k), + * - s is a Vector of size k, holding the singular values in descending order, + * - V is a Matrix of size n x k that satisfies V' * V = eye(k). + * + * We assume n is smaller than m. The singular values and the right singular vectors are derived + * from the eigenvalues and the eigenvectors of the Gramian matrix A' * A. U, the matrix + * storing the right singular vectors, is computed via matrix multiplication as + * U = A * (V * S^-1^), if requested by user. The actual method to use is determined + * automatically based on the cost: + * - If n is small (n < 100) or k is large compared with n (k > n / 2), we compute the Gramian + * matrix first and then compute its top eigenvalues and eigenvectors locally on the driver. + * This requires a single pass with O(n^2^) storage on each executor and on the driver, and + * O(n^2^ k) time on the driver. + * - Otherwise, we compute (A' * A) * v in a distributive way and send it to ARPACK's DSAUPD to + * compute (A' * A)'s top eigenvalues and eigenvectors on the driver node. This requires O(k) + * passes, O(n) storage on each executor, and O(n k) storage on the driver. * - * At most k largest non-zero singular values and associated vectors are returned. - * If there are k such values, then the dimensions of the return will be: + * Several internal parameters are set to default values. The reciprocal condition number rCond + * is set to 1e-9. All singular values smaller than rCond * sigma(0) are treated as zeros, where + * sigma(0) is the largest singular value. The maximum number of Arnoldi update iterations for + * ARPACK is set to 300 or k * 3, whichever is larger. The numerical tolerance for ARPACK's + * eigen-decomposition is set to 1e-10. * - * U is a RowMatrix of size m x k that satisfies U'U = eye(k), - * s is a Vector of size k, holding the singular values in descending order, - * and V is a Matrix of size n x k that satisfies V'V = eye(k). + * @note The conditions that decide which method to use internally and the default parameters are + * subject to change. * - * @param k number of singular values to keep. We might return less than k if there are - * numerically zero singular values. See rCond. + * @param k number of leading singular values to keep (0 < k <= n). It might return less than k if + * there are numerically zero singular values or there are not enough Ritz values + * converged before the maximum number of Arnoldi update iterations is reached (in case + * that matrix A is ill-conditioned). * @param computeU whether to compute U * @param rCond the reciprocal condition number. All singular values smaller than rCond * sigma(0) * are treated as zero, where sigma(0) is the largest singular value. - * @return SingularValueDecomposition(U, s, V) + * @return SingularValueDecomposition(U, s, V). U = null if computeU = false. */ def computeSVD( k: Int, computeU: Boolean = false, rCond: Double = 1e-9): SingularValueDecomposition[RowMatrix, Matrix] = { + // maximum number of Arnoldi update iterations for invoking ARPACK + val maxIter = math.max(300, k * 3) + // numerical tolerance for invoking ARPACK + val tol = 1e-10 + computeSVD(k, computeU, rCond, maxIter, tol, "auto") + } + + /** + * The actual SVD implementation, visible for testing. + * + * @param k number of leading singular values to keep (0 < k <= n) + * @param computeU whether to compute U + * @param rCond the reciprocal condition number + * @param maxIter max number of iterations (if ARPACK is used) + * @param tol termination tolerance (if ARPACK is used) + * @param mode computation mode (auto: determine automatically which mode to use, + * local-svd: compute gram matrix and computes its full SVD locally, + * local-eigs: compute gram matrix and computes its top eigenvalues locally, + * dist-eigs: compute the top eigenvalues of the gram matrix distributively) + * @return SingularValueDecomposition(U, s, V). U = null if computeU = false. + */ + private[mllib] def computeSVD( + k: Int, + computeU: Boolean, + rCond: Double, + maxIter: Int, + tol: Double, + mode: String): SingularValueDecomposition[RowMatrix, Matrix] = { val n = numCols().toInt - require(k > 0 && k <= n, s"Request up to n singular values k=$k n=$n.") + require(k > 0 && k <= n, s"Request up to n singular values but got k=$k and n=$n.") - val G = computeGramianMatrix() + object SVDMode extends Enumeration { + val LocalARPACK, LocalLAPACK, DistARPACK = Value + } + + val computeMode = mode match { + case "auto" => + // TODO: The conditions below are not fully tested. + if (n < 100 || k > n / 2) { + // If n is small or k is large compared with n, we better compute the Gramian matrix first + // and then compute its eigenvalues locally, instead of making multiple passes. + if (k < n / 3) { + SVDMode.LocalARPACK + } else { + SVDMode.LocalLAPACK + } + } else { + // If k is small compared with n, we use ARPACK with distributed multiplication. + SVDMode.DistARPACK + } + case "local-svd" => SVDMode.LocalLAPACK + case "local-eigs" => SVDMode.LocalARPACK + case "dist-eigs" => SVDMode.DistARPACK + case _ => throw new IllegalArgumentException(s"Do not support mode $mode.") + } + + // Compute the eigen-decomposition of A' * A. + val (sigmaSquares: BDV[Double], u: BDM[Double]) = computeMode match { + case SVDMode.LocalARPACK => + require(k < n, s"k must be smaller than n in local-eigs mode but got k=$k and n=$n.") + val G = computeGramianMatrix().toBreeze.asInstanceOf[BDM[Double]] + EigenValueDecomposition.symmetricEigs(v => G * v, n, k, tol, maxIter) + case SVDMode.LocalLAPACK => + val G = computeGramianMatrix().toBreeze.asInstanceOf[BDM[Double]] + val (uFull: BDM[Double], sigmaSquaresFull: BDV[Double], _) = brzSvd(G) + (sigmaSquaresFull, uFull) + case SVDMode.DistARPACK => + require(k < n, s"k must be smaller than n in dist-eigs mode but got k=$k and n=$n.") + EigenValueDecomposition.symmetricEigs(multiplyGramianMatrixBy, n, k, tol, maxIter) + } - // TODO: Use sparse SVD instead. - val (u: BDM[Double], sigmaSquares: BDV[Double], v: BDM[Double]) = - brzSvd(G.toBreeze.asInstanceOf[BDM[Double]]) val sigmas: BDV[Double] = brzSqrt(sigmaSquares) - // Determine effective rank. + // Determine the effective rank. val sigma0 = sigmas(0) val threshold = rCond * sigma0 var i = 0 - while (i < k && sigmas(i) >= threshold) { + // sigmas might have a length smaller than k, if some Ritz values do not satisfy the convergence + // criterion specified by tol after max number of iterations. + // Thus use i < min(k, sigmas.length) instead of i < k. + if (sigmas.length < k) { + logWarning(s"Requested $k singular values but only found ${sigmas.length} converged.") + } + while (i < math.min(k, sigmas.length) && sigmas(i) >= threshold) { i += 1 } val sk = i @@ -272,12 +253,12 @@ class RowMatrix( logWarning(s"Requested $k singular values but only found $sk nonzeros.") } - val s = Vectors.dense(util.Arrays.copyOfRange(sigmas.data, 0, sk)) - val V = Matrices.dense(n, sk, util.Arrays.copyOfRange(u.data, 0, n * sk)) + val s = Vectors.dense(Arrays.copyOfRange(sigmas.data, 0, sk)) + val V = Matrices.dense(n, sk, Arrays.copyOfRange(u.data, 0, n * sk)) if (computeU) { // N = Vk * Sk^{-1} - val N = new BDM[Double](n, sk, util.Arrays.copyOfRange(u.data, 0, n * sk)) + val N = new BDM[Double](n, sk, Arrays.copyOfRange(u.data, 0, n * sk)) var i = 0 var j = 0 while (j < sk) { @@ -364,7 +345,7 @@ class RowMatrix( if (k == n) { Matrices.dense(n, k, u.data) } else { - Matrices.dense(n, k, util.Arrays.copyOfRange(u.data, 0, n * k)) + Matrices.dense(n, k, Arrays.copyOfRange(u.data, 0, n * k)) } } @@ -372,8 +353,7 @@ class RowMatrix( * Computes column-wise summary statistics. */ def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = { - val zeroValue = new ColumnStatisticsAggregator(numCols().toInt) - val summary = rows.map(_.toBreeze).aggregate[ColumnStatisticsAggregator](zeroValue)( + val summary = rows.aggregate[MultivariateOnlineSummarizer](new MultivariateOnlineSummarizer)( (aggregator, data) => aggregator.add(data), (aggregator1, aggregator2) => aggregator1.merge(aggregator2) ) @@ -390,15 +370,24 @@ class RowMatrix( */ def multiply(B: Matrix): RowMatrix = { val n = numCols().toInt + val k = B.numCols require(n == B.numRows, s"Dimension mismatch: $n vs ${B.numRows}") require(B.isInstanceOf[DenseMatrix], s"Only support dense matrix at this time but found ${B.getClass.getName}.") - val Bb = rows.context.broadcast(B) + val Bb = rows.context.broadcast(B.toBreeze.asInstanceOf[BDM[Double]].toDenseVector.toArray) val AB = rows.mapPartitions({ iter => - val Bi = Bb.value.toBreeze.asInstanceOf[BDM[Double]] - iter.map(v => Vectors.fromBreeze(Bi.t * v.toBreeze)) + val Bi = Bb.value + iter.map(row => { + val v = BDV.zeros[Double](k) + var i = 0 + while (i < k) { + v(i) = row.toBreeze.dot(new BDV(Bi, i * n, 1, n)) + i += 1 + } + Vectors.fromBreeze(v) + }) }, preservesPartitioning = true) new RowMatrix(AB, nRows, B.numCols) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 8cca926f1c92e..fe41863bce985 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.regression -import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.{Logging, SparkException} import org.apache.spark.rdd.RDD import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.util.MLUtils._ /** * :: DeveloperApi :: @@ -124,16 +123,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] run(input, initialWeights) } - /** Prepends one to the input vector. */ - private def prependOne(vector: Vector): Vector = { - val vector1 = vector.toBreeze match { - case dv: BDV[Double] => BDV.vertcat(BDV.ones[Double](1), dv) - case sv: BSV[Double] => BSV.vertcat(new BSV[Double](Array(0), Array(1.0), 1), sv) - case v: Any => throw new IllegalArgumentException("Do not support vector type " + v.getClass) - } - Vectors.fromBreeze(vector1) - } - /** * Run the algorithm with the configured parameters on an input RDD * of LabeledPoint entries starting from the initial weights provided. @@ -147,23 +136,23 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] // Prepend an extra variable consisting of all 1.0's for the intercept. val data = if (addIntercept) { - input.map(labeledPoint => (labeledPoint.label, prependOne(labeledPoint.features))) + input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features))) } else { input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) } val initialWeightsWithIntercept = if (addIntercept) { - prependOne(initialWeights) + appendBias(initialWeights) } else { initialWeights } val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept) - val intercept = if (addIntercept) weightsWithIntercept(0) else 0.0 + val intercept = if (addIntercept) weightsWithIntercept(weightsWithIntercept.size - 1) else 0.0 val weights = if (addIntercept) { - Vectors.dense(weightsWithIntercept.toArray.slice(1, weightsWithIntercept.size)) + Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)) } else { weightsWithIntercept } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala new file mode 100644 index 0000000000000..5105b5c37aaaa --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat + +import breeze.linalg.{DenseVector => BDV} + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.{Vectors, Vector} + +/** + * :: DeveloperApi :: + * MultivariateOnlineSummarizer implements [[MultivariateStatisticalSummary]] to compute the mean, + * variance, minimum, maximum, counts, and nonzero counts for samples in sparse or dense vector + * format in a online fashion. + * + * Two MultivariateOnlineSummarizer can be merged together to have a statistical summary of + * the corresponding joint dataset. + * + * A numerically stable algorithm is implemented to compute sample mean and variance: + * Reference: [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]] + * Zero elements (including explicit zero values) are skipped when calling add(), + * to have time complexity O(nnz) instead of O(n) for each column. + */ +@DeveloperApi +class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable { + + private var n = 0 + private var currMean: BDV[Double] = _ + private var currM2n: BDV[Double] = _ + private var totalCnt: Long = 0 + private var nnz: BDV[Double] = _ + private var currMax: BDV[Double] = _ + private var currMin: BDV[Double] = _ + + /** + * Add a new sample to this summarizer, and update the statistical summary. + * + * @param sample The sample in dense/sparse vector format to be added into this summarizer. + * @return This MultivariateOnlineSummarizer object. + */ + def add(sample: Vector): this.type = { + if (n == 0) { + require(sample.toBreeze.length > 0, s"Vector should have dimension larger than zero.") + n = sample.toBreeze.length + + currMean = BDV.zeros[Double](n) + currM2n = BDV.zeros[Double](n) + nnz = BDV.zeros[Double](n) + currMax = BDV.fill(n)(Double.MinValue) + currMin = BDV.fill(n)(Double.MaxValue) + } + + require(n == sample.toBreeze.length, s"Dimensions mismatch when adding new sample." + + s" Expecting $n but got ${sample.toBreeze.length}.") + + sample.toBreeze.activeIterator.foreach { + case (_, 0.0) => // Skip explicit zero elements. + case (i, value) => + if (currMax(i) < value) { + currMax(i) = value + } + if (currMin(i) > value) { + currMin(i) = value + } + + val tmpPrevMean = currMean(i) + currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0) + currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean) + + nnz(i) += 1.0 + } + + totalCnt += 1 + this + } + + /** + * Merge another MultivariateOnlineSummarizer, and update the statistical summary. + * (Note that it's in place merging; as a result, `this` object will be modified.) + * + * @param other The other MultivariateOnlineSummarizer to be merged. + * @return This MultivariateOnlineSummarizer object. + */ + def merge(other: MultivariateOnlineSummarizer): this.type = { + if (this.totalCnt != 0 && other.totalCnt != 0) { + require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " + + s"Expecting $n but got ${other.n}.") + totalCnt += other.totalCnt + val deltaMean: BDV[Double] = currMean - other.currMean + var i = 0 + while (i < n) { + // merge mean together + if (other.currMean(i) != 0.0) { + currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) / + (nnz(i) + other.nnz(i)) + } + // merge m2n together + if (nnz(i) + other.nnz(i) != 0.0) { + currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) / + (nnz(i) + other.nnz(i)) + } + if (currMax(i) < other.currMax(i)) { + currMax(i) = other.currMax(i) + } + if (currMin(i) > other.currMin(i)) { + currMin(i) = other.currMin(i) + } + i += 1 + } + nnz += other.nnz + } else if (totalCnt == 0 && other.totalCnt != 0) { + this.n = other.n + this.currMean = other.currMean.copy + this.currM2n = other.currM2n.copy + this.totalCnt = other.totalCnt + this.nnz = other.nnz.copy + this.currMax = other.currMax.copy + this.currMin = other.currMin.copy + } + this + } + + override def mean: Vector = { + require(totalCnt > 0, s"Nothing has been added to this summarizer.") + + val realMean = BDV.zeros[Double](n) + var i = 0 + while (i < n) { + realMean(i) = currMean(i) * (nnz(i) / totalCnt) + i += 1 + } + Vectors.fromBreeze(realMean) + } + + override def variance: Vector = { + require(totalCnt > 0, s"Nothing has been added to this summarizer.") + + val realVariance = BDV.zeros[Double](n) + + val denominator = totalCnt - 1.0 + + // Sample variance is computed, if the denominator is less than 0, the variance is just 0. + if (denominator > 0.0) { + val deltaMean = currMean + var i = 0 + while (i < currM2n.size) { + realVariance(i) = + currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt + realVariance(i) /= denominator + i += 1 + } + } + + Vectors.fromBreeze(realVariance) + } + + override def count: Long = totalCnt + + override def numNonzeros: Vector = { + require(totalCnt > 0, s"Nothing has been added to this summarizer.") + + Vectors.fromBreeze(nnz) + } + + override def max: Vector = { + require(totalCnt > 0, s"Nothing has been added to this summarizer.") + + var i = 0 + while (i < n) { + if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0 + i += 1 + } + Vectors.fromBreeze(currMax) + } + + override def min: Vector = { + require(totalCnt > 0, s"Nothing has been added to this summarizer.") + + var i = 0 + while (i < n) { + if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0 + i += 1 + } + Vectors.fromBreeze(currMin) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala new file mode 100644 index 0000000000000..68f3867ba6c11 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.{Matrix, Vector} +import org.apache.spark.mllib.stat.correlation.Correlations +import org.apache.spark.rdd.RDD + +/** + * API for statistical functions in MLlib + */ +@Experimental +object Statistics { + + /** + * Compute the Pearson correlation matrix for the input RDD of Vectors. + * Returns NaN if either vector has 0 variance. + * + * @param X an RDD[Vector] for which the correlation matrix is to be computed. + * @return Pearson correlation matrix comparing columns in X. + */ + def corr(X: RDD[Vector]): Matrix = Correlations.corrMatrix(X) + + /** + * Compute the correlation matrix for the input RDD of Vectors using the specified method. + * Methods currently supported: `pearson` (default), `spearman` + * + * Note that for Spearman, a rank correlation, we need to create an RDD[Double] for each column + * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], + * which is fairly costly. Cache the input RDD before calling corr with `method = "spearman"` to + * avoid recomputing the common lineage. + * + * @param X an RDD[Vector] for which the correlation matrix is to be computed. + * @param method String specifying the method to use for computing correlation. + * Supported: `pearson` (default), `spearman` + * @return Correlation matrix comparing columns in X. + */ + def corr(X: RDD[Vector], method: String): Matrix = Correlations.corrMatrix(X, method) + + /** + * Compute the Pearson correlation for the input RDDs. + * Columns with 0 covariance produce NaN entries in the correlation matrix. + * + * @param x RDD[Double] of the same cardinality as y + * @param y RDD[Double] of the same cardinality as x + * @return A Double containing the Pearson correlation between the two input RDD[Double]s + */ + def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) + + /** + * Compute the correlation for the input RDDs using the specified method. + * Methods currently supported: pearson (default), spearman + * + * @param x RDD[Double] of the same cardinality as y + * @param y RDD[Double] of the same cardinality as x + * @param method String specifying the method to use for computing correlation. + * Supported: `pearson` (default), `spearman` + *@return A Double containing the correlation between the two input RDD[Double]s using the + * specified method. + */ + def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/Correlation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/Correlation.scala new file mode 100644 index 0000000000000..f23393d3da257 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/Correlation.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.correlation + +import org.apache.spark.mllib.linalg.{DenseVector, Matrix, Vector} +import org.apache.spark.rdd.RDD + +/** + * Trait for correlation algorithms. + */ +private[stat] trait Correlation { + + /** + * Compute correlation for two datasets. + */ + def computeCorrelation(x: RDD[Double], y: RDD[Double]): Double + + /** + * Compute the correlation matrix S, for the input matrix, where S(i, j) is the correlation + * between column i and j. S(i, j) can be NaN if the correlation is undefined for column i and j. + */ + def computeCorrelationMatrix(X: RDD[Vector]): Matrix + + /** + * Combine the two input RDD[Double]s into an RDD[Vector] and compute the correlation using the + * correlation implementation for RDD[Vector]. Can be NaN if correlation is undefined for the + * input vectors. + */ + def computeCorrelationWithMatrixImpl(x: RDD[Double], y: RDD[Double]): Double = { + val mat: RDD[Vector] = x.zip(y).map { case (xi, yi) => new DenseVector(Array(xi, yi)) } + computeCorrelationMatrix(mat)(0, 1) + } + +} + +/** + * Delegates computation to the specific correlation object based on the input method name + * + * Currently supported correlations: pearson, spearman. + * After new correlation algorithms are added, please update the documentation here and in + * Statistics.scala for the correlation APIs. + * + * Maintains the default correlation type, pearson + */ +private[stat] object Correlations { + + // Note: after new types of correlations are implemented, please update this map + val nameToObjectMap = Map(("pearson", PearsonCorrelation), ("spearman", SpearmanCorrelation)) + val defaultCorrName: String = "pearson" + val defaultCorr: Correlation = nameToObjectMap(defaultCorrName) + + def corr(x: RDD[Double], y: RDD[Double], method: String = defaultCorrName): Double = { + val correlation = getCorrelationFromName(method) + correlation.computeCorrelation(x, y) + } + + def corrMatrix(X: RDD[Vector], method: String = defaultCorrName): Matrix = { + val correlation = getCorrelationFromName(method) + correlation.computeCorrelationMatrix(X) + } + + /** + * Match input correlation name with a known name via simple string matching + * + * private to stat for ease of unit testing + */ + private[stat] def getCorrelationFromName(method: String): Correlation = { + try { + nameToObjectMap(method) + } catch { + case nse: NoSuchElementException => + throw new IllegalArgumentException("Unrecognized method name. Supported correlations: " + + nameToObjectMap.keys.mkString(", ")) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala new file mode 100644 index 0000000000000..23b291eee070b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.correlation + +import breeze.linalg.{DenseMatrix => BDM} + +import org.apache.spark.Logging +import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector} +import org.apache.spark.mllib.linalg.distributed.RowMatrix +import org.apache.spark.rdd.RDD + +/** + * Compute Pearson correlation for two RDDs of the type RDD[Double] or the correlation matrix + * for an RDD of the type RDD[Vector]. + * + * Definition of Pearson correlation can be found at + * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient + */ +private[stat] object PearsonCorrelation extends Correlation with Logging { + + /** + * Compute the Pearson correlation for two datasets. NaN if either vector has 0 variance. + */ + override def computeCorrelation(x: RDD[Double], y: RDD[Double]): Double = { + computeCorrelationWithMatrixImpl(x, y) + } + + /** + * Compute the Pearson correlation matrix S, for the input matrix, where S(i, j) is the + * correlation between column i and j. 0 covariance results in a correlation value of Double.NaN. + */ + override def computeCorrelationMatrix(X: RDD[Vector]): Matrix = { + val rowMatrix = new RowMatrix(X) + val cov = rowMatrix.computeCovariance() + computeCorrelationMatrixFromCovariance(cov) + } + + /** + * Compute the Pearson correlation matrix from the covariance matrix. + * 0 covariance results in a correlation value of Double.NaN. + */ + def computeCorrelationMatrixFromCovariance(covarianceMatrix: Matrix): Matrix = { + val cov = covarianceMatrix.toBreeze.asInstanceOf[BDM[Double]] + val n = cov.cols + + // Compute the standard deviation on the diagonals first + var i = 0 + while (i < n) { + // TODO remove once covariance numerical issue resolved. + cov(i, i) = if (closeToZero(cov(i, i))) 0.0 else math.sqrt(cov(i, i)) + i +=1 + } + + // Loop through columns since cov is column major + var j = 0 + var sigma = 0.0 + var containNaN = false + while (j < n) { + sigma = cov(j, j) + i = 0 + while (i < j) { + val corr = if (sigma == 0.0 || cov(i, i) == 0.0) { + containNaN = true + Double.NaN + } else { + cov(i, j) / (sigma * cov(i, i)) + } + cov(i, j) = corr + cov(j, i) = corr + i += 1 + } + j += 1 + } + + // put 1.0 on the diagonals + i = 0 + while (i < n) { + cov(i, i) = 1.0 + i +=1 + } + + if (containNaN) { + logWarning("Pearson correlation matrix contains NaN values.") + } + + Matrices.fromBreeze(cov) + } + + private def closeToZero(value: Double, threshhold: Double = 1e-12): Boolean = { + math.abs(value) <= threshhold + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala new file mode 100644 index 0000000000000..88de2c82479b7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.correlation + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{Logging, HashPartitioner} +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.linalg.{DenseVector, Matrix, Vector} +import org.apache.spark.rdd.{CoGroupedRDD, RDD} + +/** + * Compute Spearman's correlation for two RDDs of the type RDD[Double] or the correlation matrix + * for an RDD of the type RDD[Vector]. + * + * Definition of Spearman's correlation can be found at + * http://en.wikipedia.org/wiki/Spearman's_rank_correlation_coefficient + */ +private[stat] object SpearmanCorrelation extends Correlation with Logging { + + /** + * Compute Spearman's correlation for two datasets. + */ + override def computeCorrelation(x: RDD[Double], y: RDD[Double]): Double = { + computeCorrelationWithMatrixImpl(x, y) + } + + /** + * Compute Spearman's correlation matrix S, for the input matrix, where S(i, j) is the + * correlation between column i and j. + * + * Input RDD[Vector] should be cached or checkpointed if possible since it would be split into + * numCol RDD[Double]s, each of which sorted, and the joined back into a single RDD[Vector]. + */ + override def computeCorrelationMatrix(X: RDD[Vector]): Matrix = { + val indexed = X.zipWithUniqueId() + + val numCols = X.first.size + if (numCols > 50) { + logWarning("Computing the Spearman correlation matrix can be slow for large RDDs with more" + + " than 50 columns.") + } + val ranks = new Array[RDD[(Long, Double)]](numCols) + + // Note: we use a for loop here instead of a while loop with a single index variable + // to avoid race condition caused by closure serialization + for (k <- 0 until numCols) { + val column = indexed.map { case (vector, index) => (vector(k), index) } + ranks(k) = getRanks(column) + } + + val ranksMat: RDD[Vector] = makeRankMatrix(ranks, X) + PearsonCorrelation.computeCorrelationMatrix(ranksMat) + } + + /** + * Compute the ranks for elements in the input RDD, using the average method for ties. + * + * With the average method, elements with the same value receive the same rank that's computed + * by taking the average of their positions in the sorted list. + * e.g. ranks([2, 1, 0, 2]) = [2.5, 1.0, 0.0, 2.5] + * Note that positions here are 0-indexed, instead of the 1-indexed as in the definition for + * ranks in the standard definition for Spearman's correlation. This does not affect the final + * results and is slightly more performant. + * + * @param indexed RDD[(Double, Long)] containing pairs of the format (originalValue, uniqueId) + * @return RDD[(Long, Double)] containing pairs of the format (uniqueId, rank), where uniqueId is + * copied from the input RDD. + */ + private def getRanks(indexed: RDD[(Double, Long)]): RDD[(Long, Double)] = { + // Get elements' positions in the sorted list for computing average rank for duplicate values + val sorted = indexed.sortByKey().zipWithIndex() + + val ranks: RDD[(Long, Double)] = sorted.mapPartitions { iter => + // add an extra element to signify the end of the list so that flatMap can flush the last + // batch of duplicates + val padded = iter ++ + Iterator[((Double, Long), Long)](((Double.NaN, -1L), -1L)) + var lastVal = 0.0 + var firstRank = 0.0 + val idBuffer = new ArrayBuffer[Long]() + padded.flatMap { case ((v, id), rank) => + if (v == lastVal && id != Long.MinValue) { + idBuffer += id + Iterator.empty + } else { + val entries = if (idBuffer.size == 0) { + // edge case for the first value matching the initial value of lastVal + Iterator.empty + } else if (idBuffer.size == 1) { + Iterator((idBuffer(0), firstRank)) + } else { + val averageRank = firstRank + (idBuffer.size - 1.0) / 2.0 + idBuffer.map(id => (id, averageRank)) + } + lastVal = v + firstRank = rank + idBuffer.clear() + idBuffer += id + entries + } + } + } + ranks + } + + private def makeRankMatrix(ranks: Array[RDD[(Long, Double)]], input: RDD[Vector]): RDD[Vector] = { + val partitioner = new HashPartitioner(input.partitions.size) + val cogrouped = new CoGroupedRDD[Long](ranks, partitioner) + cogrouped.map { case (_, values: Seq[Seq[Double]]) => new DenseVector(values.flatten.toArray) } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 74d5d7ba10960..ad32e3f4560fe 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -77,11 +77,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Max memory usage for aggregates val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") - val numElementsPerNode = - strategy.algo match { - case Classification => 2 * numBins * numFeatures - case Regression => 3 * numBins * numFeatures - } + val numElementsPerNode = DecisionTree.getElementsPerNode(numFeatures, numBins, + strategy.numClassesForClassification, strategy.isMulticlassWithCategoricalFeatures, + strategy.algo) logDebug("numElementsPerNode = " + numElementsPerNode) val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array @@ -109,8 +107,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logDebug("#####################################") // Find best split for all nodes at a level. - val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, - level, filters, splits, bins, maxLevelForSingleGroup) + val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, + strategy, level, filters, splits, bins, maxLevelForSingleGroup) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { // Extract info for nodes at the current level. @@ -212,7 +210,7 @@ object DecisionTree extends Serializable with Logging { * @return a DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { - new DecisionTree(strategy).train(input: RDD[LabeledPoint]) + new DecisionTree(strategy).train(input) } /** @@ -233,10 +231,33 @@ object DecisionTree extends Serializable with Logging { algo: Algo, impurity: Impurity, maxDepth: Int): DecisionTreeModel = { - val strategy = new Strategy(algo,impurity,maxDepth) - new DecisionTree(strategy).train(input: RDD[LabeledPoint]) + val strategy = new Strategy(algo, impurity, maxDepth) + new DecisionTree(strategy).train(input) } + /** + * Method to train a decision tree model where the instances are represented as an RDD of + * (label, features) pairs. The method supports binary classification and regression. For the + * binary classification, the label for each instance should either be 0 or 1 to denote the two + * classes. + * + * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as + * training data + * @param algo algorithm, classification or regression + * @param impurity impurity criterion used for information gain calculation + * @param maxDepth maxDepth maximum depth of the tree + * @param numClassesForClassification number of classes for classification. Default value of 2. + * @return a DecisionTreeModel that can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + algo: Algo, + impurity: Impurity, + maxDepth: Int, + numClassesForClassification: Int): DecisionTreeModel = { + val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification) + new DecisionTree(strategy).train(input) + } /** * Method to train a decision tree model where the instances are represented as an RDD of @@ -250,6 +271,7 @@ object DecisionTree extends Serializable with Logging { * @param algo classification or regression * @param impurity criterion used for information gain calculation * @param maxDepth maximum depth of the tree + * @param numClassesForClassification number of classes for classification. Default value of 2. * @param maxBins maximum number of bins used for splitting features * @param quantileCalculationStrategy algorithm for calculating quantiles * @param categoricalFeaturesInfo A map storing information about the categorical variables and @@ -264,12 +286,13 @@ object DecisionTree extends Serializable with Logging { algo: Algo, impurity: Impurity, maxDepth: Int, + numClassesForClassification: Int, maxBins: Int, quantileCalculationStrategy: QuantileStrategy, categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { - val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy, - categoricalFeaturesInfo) - new DecisionTree(strategy).train(input: RDD[LabeledPoint]) + val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins, + quantileCalculationStrategy, categoricalFeaturesInfo) + new DecisionTree(strategy).train(input) } private val InvalidBinIndex = -1 @@ -381,6 +404,14 @@ object DecisionTree extends Serializable with Logging { logDebug("numFeatures = " + numFeatures) val numBins = bins(0).length logDebug("numBins = " + numBins) + val numClasses = strategy.numClassesForClassification + logDebug("numClasses = " + numClasses) + val isMulticlassClassification = strategy.isMulticlassClassification + logDebug("isMulticlassClassification = " + isMulticlassClassification) + val isMulticlassClassificationWithCategoricalFeatures + = strategy.isMulticlassWithCategoricalFeatures + logDebug("isMultiClassWithCategoricalFeatures = " + + isMulticlassClassificationWithCategoricalFeatures) // shift when more than one group is used at deep tree level val groupShift = numNodes * groupIndex @@ -436,10 +467,8 @@ object DecisionTree extends Serializable with Logging { /** * Find bin for one feature. */ - def findBin( - featureIndex: Int, - labeledPoint: LabeledPoint, - isFeatureContinuous: Boolean): Int = { + def findBin(featureIndex: Int, labeledPoint: LabeledPoint, + isFeatureContinuous: Boolean, isSpaceSufficientForAllCategoricalSplits: Boolean): Int = { val binForFeatures = bins(featureIndex) val feature = labeledPoint.features(featureIndex) @@ -467,17 +496,28 @@ object DecisionTree extends Serializable with Logging { -1 } + /** + * Sequential search helper method to find bin for categorical feature in multiclass + * classification. The category is returned since each category can belong to multiple + * splits. The actual left/right child allocation per split is performed in the + * sequential phase of the bin aggregate operation. + */ + def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = { + labeledPoint.features(featureIndex).toInt + } + /** * Sequential search helper method to find bin for categorical feature. */ - def sequentialBinSearchForCategoricalFeature(): Int = { - val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex) + def sequentialBinSearchForOrderedCategoricalFeatureInClassification(): Int = { + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 var binIndex = 0 while (binIndex < numCategoricalBins) { val bin = bins(featureIndex)(binIndex) - val category = bin.category + val categories = bin.highSplit.categories val features = labeledPoint.features - if (category == features(featureIndex)) { + if (categories.contains(features(featureIndex))) { return binIndex } binIndex += 1 @@ -494,7 +534,13 @@ object DecisionTree extends Serializable with Logging { binIndex } else { // Perform sequential search to find bin for categorical features. - val binIndex = sequentialBinSearchForCategoricalFeature() + val binIndex = { + if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { + sequentialBinSearchForUnorderedCategoricalFeatureInClassification() + } else { + sequentialBinSearchForOrderedCategoricalFeatureInClassification() + } + } if (binIndex == -1){ throw new UnknownError("no bin was found for categorical variable.") } @@ -506,13 +552,16 @@ object DecisionTree extends Serializable with Logging { * Finds bins for all nodes (and all features) at a given level. * For l nodes, k features the storage is as follows: * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk, - * where b_ij is an integer between 0 and numBins - 1. + * where b_ij is an integer between 0 and numBins - 1 for regressions and binary + * classification and the categorical feature value in multiclass classification. * Invalid sample is denoted by noting bin for feature 1 as -1. */ def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { // Calculate bin index and label per feature per node. val arr = new Array[Double](1 + (numFeatures * numNodes)) + // First element of the array is the label of the instance. arr(0) = labeledPoint.label + // Iterate over nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { val parentFilters = findParentFilters(nodeIndex) @@ -525,8 +574,19 @@ object DecisionTree extends Serializable with Logging { } else { var featureIndex = 0 while (featureIndex < numFeatures) { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinuous) + val featureInfo = strategy.categoricalFeaturesInfo.get(featureIndex) + val isFeatureContinuous = featureInfo.isEmpty + if (isFeatureContinuous) { + arr(shift + featureIndex) + = findBin(featureIndex, labeledPoint, isFeatureContinuous, false) + } else { + val featureCategories = featureInfo.get + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + arr(shift + featureIndex) + = findBin(featureIndex, labeledPoint, isFeatureContinuous, + isSpaceSufficientForAllCategoricalSplits) + } featureIndex += 1 } } @@ -535,18 +595,61 @@ object DecisionTree extends Serializable with Logging { arr } + // Find feature bins for all nodes at a level. + val binMappedRDD = input.map(x => findBinsForLevel(x)) + + def updateBinForOrderedFeature(arr: Array[Double], agg: Array[Double], nodeIndex: Int, + label: Double, featureIndex: Int) = { + + // Find the bin index for this feature. + val arrShift = 1 + numFeatures * nodeIndex + val arrIndex = arrShift + featureIndex + // Update the left or right count for one bin. + val aggShift = numClasses * numBins * numFeatures * nodeIndex + val aggIndex + = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses + val labelInt = label.toInt + agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1 + } + + def updateBinForUnorderedFeature(nodeIndex: Int, featureIndex: Int, arr: Array[Double], + label: Double, agg: Array[Double], rightChildShift: Int) = { + // Find the bin index for this feature. + val arrShift = 1 + numFeatures * nodeIndex + val arrIndex = arrShift + featureIndex + // Update the left or right count for one bin. + val aggShift = numClasses * numBins * numFeatures * nodeIndex + val aggIndex + = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses + // Find all matching bins and increment their values + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 + var binIndex = 0 + while (binIndex < numCategoricalBins) { + val labelInt = label.toInt + if (bins(featureIndex)(binIndex).highSplit.categories.contains(labelInt)) { + agg(aggIndex + binIndex) + = agg(aggIndex + binIndex) + 1 + } else { + agg(rightChildShift + aggIndex + binIndex) + = agg(rightChildShift + aggIndex + binIndex) + 1 + } + binIndex += 1 + } + } + /** * Performs a sequential aggregation over a partition for classification. For l nodes, * k features, either the left count or the right count of one of the p bins is * incremented based upon whether the feature is classified as 0 or 1. * * @param agg Array[Double] storing aggregate calculation of size - * 2 * numSplits * numFeatures*numNodes for classification + * numClasses * numSplits * numFeatures*numNodes for classification * @param arr Array[Double] of size 1 + (numFeatures * numNodes) * @return Array[Double] storing aggregate calculation of size * 2 * numSplits * numFeatures * numNodes for classification */ - def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { + def orderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -559,15 +662,52 @@ object DecisionTree extends Serializable with Logging { // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures) { - // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex - // Update the left or right count for one bin. - val aggShift = 2 * numBins * numFeatures * nodeIndex - val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2 - label match { - case 0.0 => agg(aggIndex) = agg(aggIndex) + 1 - case 1.0 => agg(aggIndex + 1) = agg(aggIndex + 1) + 1 + updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) + featureIndex += 1 + } + } + nodeIndex += 1 + } + } + + /** + * Performs a sequential aggregation over a partition for classification. For l nodes, + * k features, either the left count or the right count of one of the p bins is + * incremented based upon whether the feature is classified as 0 or 1. + * + * @param agg Array[Double] storing aggregate calculation of size + * numClasses * numSplits * numFeatures*numNodes for classification + * @param arr Array[Double] of size 1 + (numFeatures * numNodes) + * @return Array[Double] storing aggregate calculation of size + * 2 * numClasses * numSplits * numFeatures * numNodes for classification + */ + def unorderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { + // Iterate over all nodes. + var nodeIndex = 0 + while (nodeIndex < numNodes) { + // Check whether the instance was valid for this nodeIndex. + val validSignalIndex = 1 + numFeatures * nodeIndex + val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex + if (isSampleValidForNode) { + val rightChildShift = numClasses * numBins * numFeatures * numNodes + // actual class label + val label = arr(0) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { + updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) + } else { + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + if (isSpaceSufficientForAllCategoricalSplits) { + updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, + rightChildShift) + } else { + updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) + } } featureIndex += 1 } @@ -586,7 +726,7 @@ object DecisionTree extends Serializable with Logging { * @return Array[Double] storing aggregate calculation of size * 3 * numSplits * numFeatures * numNodes for regression */ - def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) { + def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -620,17 +760,20 @@ object DecisionTree extends Serializable with Logging { */ def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { strategy.algo match { - case Classification => classificationBinSeqOp(arr, agg) + case Classification => + if(isMulticlassClassificationWithCategoricalFeatures) { + unorderedClassificationBinSeqOp(arr, agg) + } else { + orderedClassificationBinSeqOp(arr, agg) + } case Regression => regressionBinSeqOp(arr, agg) } agg } // Calculate bin aggregate length for classification or regression. - val binAggregateLength = strategy.algo match { - case Classification => 2 * numBins * numFeatures * numNodes - case Regression => 3 * numBins * numFeatures * numNodes - } + val binAggregateLength = numNodes * getElementsPerNode(numFeatures, numBins, numClasses, + isMulticlassClassificationWithCategoricalFeatures, strategy.algo) logDebug("binAggregateLength = " + binAggregateLength) /** @@ -649,9 +792,6 @@ object DecisionTree extends Serializable with Logging { combinedAggregate } - // Find feature bins for all nodes at a level. - val binMappedRDD = input.map(x => findBinsForLevel(x)) - // Calculate bin aggregates. val binAggregates = { binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) @@ -668,42 +808,55 @@ object DecisionTree extends Serializable with Logging { * @return information gain and statistics for all splits */ def calculateGainForSplit( - leftNodeAgg: Array[Array[Double]], + leftNodeAgg: Array[Array[Array[Double]]], featureIndex: Int, splitIndex: Int, - rightNodeAgg: Array[Array[Double]], + rightNodeAgg: Array[Array[Array[Double]]], topImpurity: Double): InformationGainStats = { strategy.algo match { case Classification => - val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex) - val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1) - val leftCount = left0Count + left1Count - - val right0Count = rightNodeAgg(featureIndex)(2 * splitIndex) - val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1) - val rightCount = right0Count + right1Count + var classIndex = 0 + val leftCounts: Array[Double] = new Array[Double](numClasses) + val rightCounts: Array[Double] = new Array[Double](numClasses) + var leftTotalCount = 0.0 + var rightTotalCount = 0.0 + while (classIndex < numClasses) { + val leftClassCount = leftNodeAgg(featureIndex)(splitIndex)(classIndex) + val rightClassCount = rightNodeAgg(featureIndex)(splitIndex)(classIndex) + leftCounts(classIndex) = leftClassCount + leftTotalCount += leftClassCount + rightCounts(classIndex) = rightClassCount + rightTotalCount += rightClassCount + classIndex += 1 + } val impurity = { if (level > 0) { topImpurity } else { // Calculate impurity for root node. - strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count) + val rootNodeCounts = new Array[Double](numClasses) + var classIndex = 0 + while (classIndex < numClasses) { + rootNodeCounts(classIndex) = leftCounts(classIndex) + rightCounts(classIndex) + classIndex += 1 + } + strategy.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount) } } - if (leftCount == 0) { - return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,1) + if (leftTotalCount == 0) { + return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue, 1) } - if (rightCount == 0) { - return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue,0) + if (rightTotalCount == 0) { + return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, 1) } - val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) - val rightImpurity = strategy.impurity.calculate(right0Count, right1Count) + val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount) + val rightImpurity = strategy.impurity.calculate(rightCounts, rightTotalCount) - val leftWeight = leftCount.toDouble / (leftCount + rightCount) - val rightWeight = rightCount.toDouble / (leftCount + rightCount) + val leftWeight = leftTotalCount / (leftTotalCount + rightTotalCount) + val rightWeight = rightTotalCount / (leftTotalCount + rightTotalCount) val gain = { if (level > 0) { @@ -713,17 +866,34 @@ object DecisionTree extends Serializable with Logging { } } - val predict = (left1Count + right1Count) / (leftCount + rightCount) + val totalCount = leftTotalCount + rightTotalCount - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) + // Sum of count for each label + val leftRightCounts: Array[Double] + = leftCounts.zip(rightCounts) + .map{case (leftCount, rightCount) => leftCount + rightCount} + + def indexOfLargestArrayElement(array: Array[Double]): Int = { + val result = array.foldLeft(-1, Double.MinValue, 0) { + case ((maxIndex, maxValue, currentIndex), currentValue) => + if(currentValue > maxValue) (currentIndex, currentValue, currentIndex + 1) + else (maxIndex, maxValue, currentIndex + 1) + } + if (result._1 < 0) 0 else result._1 + } + + val predict = indexOfLargestArrayElement(leftRightCounts) + val prob = leftRightCounts(predict) / totalCount + + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) case Regression => - val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex) - val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1) - val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2) + val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0) + val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1) + val leftSumSquares = leftNodeAgg(featureIndex)(splitIndex)(2) - val rightCount = rightNodeAgg(featureIndex)(3 * splitIndex) - val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1) - val rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2) + val rightCount = rightNodeAgg(featureIndex)(splitIndex)(0) + val rightSum = rightNodeAgg(featureIndex)(splitIndex)(1) + val rightSumSquares = rightNodeAgg(featureIndex)(splitIndex)(2) val impurity = { if (level > 0) { @@ -768,104 +938,149 @@ object DecisionTree extends Serializable with Logging { /** * Extracts left and right split aggregates. * @param binData Array[Double] of size 2*numFeatures*numSplits - * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double], - * Array[Double]) where each array is of size(numFeature,2*(numSplits-1)) + * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\], + * Array[Array[Array[Double\]\]\]) where each array is of size(numFeature, + * (numBins - 1), numClasses) */ def extractLeftRightNodeAggregates( - binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { + binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { + + + def findAggForOrderedFeatureClassification( + leftNodeAgg: Array[Array[Array[Double]]], + rightNodeAgg: Array[Array[Array[Double]]], + featureIndex: Int) { + + // shift for this featureIndex + val shift = numClasses * featureIndex * numBins + + var classIndex = 0 + while (classIndex < numClasses) { + // left node aggregate for the lowest split + leftNodeAgg(featureIndex)(0)(classIndex) = binData(shift + classIndex) + // right node aggregate for the highest split + rightNodeAgg(featureIndex)(numBins - 2)(classIndex) + = binData(shift + (numClasses * (numBins - 1)) + classIndex) + classIndex += 1 + } + + // Iterate over all splits. + var splitIndex = 1 + while (splitIndex < numBins - 1) { + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + var innerClassIndex = 0 + while (innerClassIndex < numClasses) { + leftNodeAgg(featureIndex)(splitIndex)(innerClassIndex) + = binData(shift + numClasses * splitIndex + innerClassIndex) + + leftNodeAgg(featureIndex)(splitIndex - 1)(innerClassIndex) + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(innerClassIndex) = + binData(shift + (numClasses * (numBins - 1 - splitIndex) + innerClassIndex)) + + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(innerClassIndex) + innerClassIndex += 1 + } + splitIndex += 1 + } + } + + def findAggForUnorderedFeatureClassification( + leftNodeAgg: Array[Array[Array[Double]]], + rightNodeAgg: Array[Array[Array[Double]]], + featureIndex: Int) { + + val rightChildShift = numClasses * numBins * numFeatures + var splitIndex = 0 + while (splitIndex < numBins - 1) { + var classIndex = 0 + while (classIndex < numClasses) { + // shift for this featureIndex + val shift = numClasses * featureIndex * numBins + splitIndex * numClasses + val leftBinValue = binData(shift + classIndex) + val rightBinValue = binData(rightChildShift + shift + classIndex) + leftNodeAgg(featureIndex)(splitIndex)(classIndex) = leftBinValue + rightNodeAgg(featureIndex)(splitIndex)(classIndex) = rightBinValue + classIndex += 1 + } + splitIndex += 1 + } + } + + def findAggForRegression( + leftNodeAgg: Array[Array[Array[Double]]], + rightNodeAgg: Array[Array[Array[Double]]], + featureIndex: Int) { + + // shift for this featureIndex + val shift = 3 * featureIndex * numBins + // left node aggregate for the lowest split + leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0) + leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1) + leftNodeAgg(featureIndex)(0)(2) = binData(shift + 2) + + // right node aggregate for the highest split + rightNodeAgg(featureIndex)(numBins - 2)(0) = + binData(shift + (3 * (numBins - 1))) + rightNodeAgg(featureIndex)(numBins - 2)(1) = + binData(shift + (3 * (numBins - 1)) + 1) + rightNodeAgg(featureIndex)(numBins - 2)(2) = + binData(shift + (3 * (numBins - 1)) + 2) + + // Iterate over all splits. + var splitIndex = 1 + while (splitIndex < numBins - 1) { + var i = 0 // index for regression histograms + while (i < 3) { // count, sum, sum^2 + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + leftNodeAgg(featureIndex)(splitIndex)(i) = binData(shift + 3 * splitIndex + i) + + leftNodeAgg(featureIndex)(splitIndex - 1)(i) + // calculating right node aggregate for a split as a sum of right node aggregate of a + // higher split and the right bin aggregate of a bin where the split is a low split + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(i) = + binData(shift + (3 * (numBins - 1 - splitIndex) + i)) + + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(i) + i += 1 + } + splitIndex += 1 + } + } + strategy.algo match { case Classification => // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) - val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) - // Iterate over all features. + val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) + val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) var featureIndex = 0 while (featureIndex < numFeatures) { - // shift for this featureIndex - val shift = 2 * featureIndex * numBins - - // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0) = binData(shift + 0) - leftNodeAgg(featureIndex)(1) = binData(shift + 1) - - // right node aggregate for the highest split - rightNodeAgg(featureIndex)(2 * (numBins - 2)) - = binData(shift + (2 * (numBins - 1))) - rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1) - = binData(shift + (2 * (numBins - 1)) + 1) - - // Iterate over all splits. - var splitIndex = 1 - while (splitIndex < numBins - 1) { - // calculating left node aggregate for a split as a sum of left node aggregate of a - // lower split and the left bin aggregate of a bin where the split is a high split - leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2 * splitIndex) + - leftNodeAgg(featureIndex)(2 * splitIndex - 2) - leftNodeAgg(featureIndex)(2 * splitIndex + 1) = binData(shift + 2 * splitIndex + 1) + - leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1) - - // calculating right node aggregate for a split as a sum of right node aggregate of a - // higher split and the right bin aggregate of a bin where the split is a low split - rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) = - binData(shift + (2 *(numBins - 1 - splitIndex))) + - rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex)) - rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) = - binData(shift + (2* (numBins - 1 - splitIndex) + 1)) + - rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1) - - splitIndex += 1 + if (isMulticlassClassificationWithCategoricalFeatures){ + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { + findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + } else { + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + if (isSpaceSufficientForAllCategoricalSplits) { + findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + } else { + findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + } + } + } else { + findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) } featureIndex += 1 } + (leftNodeAgg, rightNodeAgg) case Regression => // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) - val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) + val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) + val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures) { - // shift for this featureIndex - val shift = 3 * featureIndex * numBins - // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0) = binData(shift + 0) - leftNodeAgg(featureIndex)(1) = binData(shift + 1) - leftNodeAgg(featureIndex)(2) = binData(shift + 2) - - // right node aggregate for the highest split - rightNodeAgg(featureIndex)(3 * (numBins - 2)) = - binData(shift + (3 * (numBins - 1))) - rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) = - binData(shift + (3 * (numBins - 1)) + 1) - rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) = - binData(shift + (3 * (numBins - 1)) + 2) - - // Iterate over all splits. - var splitIndex = 1 - while (splitIndex < numBins - 1) { - // calculating left node aggregate for a split as a sum of left node aggregate of a - // lower split and the left bin aggregate of a bin where the split is a high split - leftNodeAgg(featureIndex)(3 * splitIndex) = binData(shift + 3 * splitIndex) + - leftNodeAgg(featureIndex)(3 * splitIndex - 3) - leftNodeAgg(featureIndex)(3 * splitIndex + 1) = binData(shift + 3 * splitIndex + 1) + - leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1) - leftNodeAgg(featureIndex)(3 * splitIndex + 2) = binData(shift + 3 * splitIndex + 2) + - leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2) - - // calculating right node aggregate for a split as a sum of right node aggregate of a - // higher split and the right bin aggregate of a bin where the split is a low split - rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) = - binData(shift + (3 * (numBins - 1 - splitIndex))) + - rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex)) - rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) = - binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) + - rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1) - rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) = - binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) + - rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2) - - splitIndex += 1 - } + findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex) featureIndex += 1 } (leftNodeAgg, rightNodeAgg) @@ -876,8 +1091,8 @@ object DecisionTree extends Serializable with Logging { * Calculates information gain for all nodes splits. */ def calculateGainsForAllNodeSplits( - leftNodeAgg: Array[Array[Double]], - rightNodeAgg: Array[Array[Double]], + leftNodeAgg: Array[Array[Array[Double]]], + rightNodeAgg: Array[Array[Array[Double]]], nodeImpurity: Double): Array[Array[InformationGainStats]] = { val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) @@ -918,7 +1133,22 @@ object DecisionTree extends Serializable with Logging { while (featureIndex < numFeatures) { // Iterate over all splits. var splitIndex = 0 - while (splitIndex < numBins - 1) { + val maxSplitIndex : Double = { + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { + numBins - 1 + } else { // Categorical feature + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { + math.pow(2.0, featureCategories - 1).toInt - 1 + } else { // Binary classification + featureCategories + } + } + } + while (splitIndex < maxSplitIndex) { val gainStats = gains(featureIndex)(splitIndex) if (gainStats.gain > bestGainStats.gain) { bestGainStats = gainStats @@ -944,9 +1174,23 @@ object DecisionTree extends Serializable with Logging { def getBinDataForNode(node: Int): Array[Double] = { strategy.algo match { case Classification => - val shift = 2 * node * numBins * numFeatures - val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures) - binsForNode + if (isMulticlassClassificationWithCategoricalFeatures) { + val shift = numClasses * node * numBins * numFeatures + val rightChildShift = numClasses * numBins * numFeatures * numNodes + val binsForNode = { + val leftChildData + = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) + val rightChildData + = binAggregates.slice(rightChildShift + shift, + rightChildShift + shift + numClasses * numBins * numFeatures) + leftChildData ++ rightChildData + } + binsForNode + } else { + val shift = numClasses * node * numBins * numFeatures + val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) + binsForNode + } case Regression => val shift = 3 * node * numBins * numFeatures val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures) @@ -963,14 +1207,26 @@ object DecisionTree extends Serializable with Logging { val binsForNode: Array[Double] = getBinDataForNode(node) logDebug("nodeImpurityIndex = " + nodeImpurityIndex) val parentNodeImpurity = parentImpurities(nodeImpurityIndex) - logDebug("node impurity = " + parentNodeImpurity) + logDebug("parent node impurity = " + parentNodeImpurity) bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) node += 1 } - bestSplits } + private def getElementsPerNode(numFeatures: Int, numBins: Int, numClasses: Int, + isMulticlassClassificationWithCategoricalFeatures: Boolean, algo: Algo): Int = { + algo match { + case Classification => + if (isMulticlassClassificationWithCategoricalFeatures) { + 2 * numClasses * numBins * numFeatures + } else { + numClasses * numBins * numFeatures + } + case Regression => 3 * numBins * numFeatures + } + } + /** * Returns split and bins for decision tree calculation. * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data @@ -992,17 +1248,23 @@ object DecisionTree extends Serializable with Logging { val maxBins = strategy.maxBins val numBins = if (maxBins <= count) maxBins else count.toInt logDebug("numBins = " + numBins) + val isMulticlassClassification = strategy.isMulticlassClassification + logDebug("isMulticlassClassification = " + isMulticlassClassification) + /* - * TODO: Add a require statement ensuring #bins is always greater than the categories. + * Ensure #bins is always greater than the categories. For multiclass classification, + * #bins should be greater than 2^(maxCategories - 1) - 1. * It's a limitation of the current implementation but a reasonable trade-off since features * with large number of categories get favored over continuous features. */ if (strategy.categoricalFeaturesInfo.size > 0) { val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2 - require(numBins >= maxCategoriesForFeatures) + require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " + + "in categorical features") } + // Calculate the number of sample for approximate quantile calculation. val requiredSamples = numBins*numBins val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0 @@ -1036,48 +1298,93 @@ object DecisionTree extends Serializable with Logging { val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List()) splits(featureIndex)(index) = split } - } else { - val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex) - require(maxFeatureValue < numBins, "number of categories should be less than number " + - "of bins") - - // For categorical variables, each bin is a category. The bins are sorted and they - // are ordered by calculating the centroid of their corresponding labels. - val centroidForCategories = - sampledInput.map(lp => (lp.features(featureIndex),lp.label)) - .groupBy(_._1) - .mapValues(x => x.map(_._2).sum / x.map(_._1).length) - - // Check for missing categorical variables and putting them last in the sorted list. - val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]() - for (i <- 0 until maxFeatureValue) { - if (centroidForCategories.contains(i)) { - fullCentroidForCategories(i) = centroidForCategories(i) - } else { - fullCentroidForCategories(i) = Double.MaxValue - } - } - - // bins sorted by centroids - val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2) - - logDebug("centriod for categorical variable = " + categoriesSortedByCentroid) - - var categoriesForSplit = List[Double]() - categoriesSortedByCentroid.iterator.zipWithIndex.foreach { - case ((key, value), index) => - categoriesForSplit = key :: categoriesForSplit - splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical, - categoriesForSplit) + } else { // Categorical feature + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + + // Use different bin/split calculation strategy for categorical features in multiclass + // classification that satisfy the space constraint + if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { + // 2^(maxFeatureValue- 1) - 1 combinations + var index = 0 + while (index < math.pow(2.0, featureCategories - 1).toInt - 1) { + val categories: List[Double] + = extractMultiClassCategories(index + 1, featureCategories) + splits(featureIndex)(index) + = new Split(featureIndex, Double.MinValue, Categorical, categories) bins(featureIndex)(index) = { if (index == 0) { - new Bin(new DummyCategoricalSplit(featureIndex, Categorical), - splits(featureIndex)(0), Categorical, key) + new Bin( + new DummyCategoricalSplit(featureIndex, Categorical), + splits(featureIndex)(0), + Categorical, + Double.MinValue) } else { - new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), - Categorical, key) + new Bin( + splits(featureIndex)(index - 1), + splits(featureIndex)(index), + Categorical, + Double.MinValue) } } + index += 1 + } + } else { + + val centroidForCategories = { + if (isMulticlassClassification) { + // For categorical variables in multiclass classification, + // each bin is a category. The bins are sorted and they + // are ordered by calculating the impurity of their corresponding labels. + sampledInput.map(lp => (lp.features(featureIndex), lp.label)) + .groupBy(_._1) + .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble)) + .map(x => (x._1, x._2.values.toArray)) + .map(x => (x._1, strategy.impurity.calculate(x._2,x._2.sum))) + } else { // regression or binary classification + // For categorical variables in regression and binary classification, + // each bin is a category. The bins are sorted and they + // are ordered by calculating the centroid of their corresponding labels. + sampledInput.map(lp => (lp.features(featureIndex), lp.label)) + .groupBy(_._1) + .mapValues(x => x.map(_._2).sum / x.map(_._1).length) + } + } + + logDebug("centriod for categories = " + centroidForCategories.mkString(",")) + + // Check for missing categorical variables and putting them last in the sorted list. + val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]() + for (i <- 0 until featureCategories) { + if (centroidForCategories.contains(i)) { + fullCentroidForCategories(i) = centroidForCategories(i) + } else { + fullCentroidForCategories(i) = Double.MaxValue + } + } + + // bins sorted by centroids + val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2) + + logDebug("centriod for categorical variable = " + categoriesSortedByCentroid) + + var categoriesForSplit = List[Double]() + categoriesSortedByCentroid.iterator.zipWithIndex.foreach { + case ((key, value), index) => + categoriesForSplit = key :: categoriesForSplit + splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, + Categorical, categoriesForSplit) + bins(featureIndex)(index) = { + if (index == 0) { + new Bin(new DummyCategoricalSplit(featureIndex, Categorical), + splits(featureIndex)(0), Categorical, key) + } else { + new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), + Categorical, key) + } + } + } } } featureIndex += 1 @@ -1107,4 +1414,29 @@ object DecisionTree extends Serializable with Logging { throw new UnsupportedOperationException("approximate histogram not supported yet.") } } + + /** + * Nested method to extract list of eligible categories given an index. It extracts the + * position of ones in a binary representation of the input. If binary + * representation of an number is 01101 (13), the output list should (3.0, 2.0, + * 0.0). The maxFeatureValue depict the number of rightmost digits that will be tested for ones. + */ + private[tree] def extractMultiClassCategories( + input: Int, + maxFeatureValue: Int): List[Double] = { + var categories = List[Double]() + var j = 0 + var bitShiftedInput = input + while (j < maxFeatureValue) { + if (bitShiftedInput % 2 != 0) { + // updating the list of categories. + categories = j.toDouble :: categories + } + // Right shift by one + bitShiftedInput = bitShiftedInput >> 1 + j += 1 + } + categories + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 1b505fd76eb75..7c027ac2fda6b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -28,6 +28,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * @param algo classification or regression * @param impurity criterion used for information gain calculation * @param maxDepth maximum depth of the tree + * @param numClassesForClassification number of classes for classification. Default value is 2 + * leads to binary classification * @param maxBins maximum number of bins used for splitting features * @param quantileCalculationStrategy algorithm for calculating quantiles * @param categoricalFeaturesInfo A map storing information about the categorical variables and the @@ -44,7 +46,15 @@ class Strategy ( val algo: Algo, val impurity: Impurity, val maxDepth: Int, + val numClassesForClassification: Int = 2, val maxBins: Int = 100, val quantileCalculationStrategy: QuantileStrategy = Sort, val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), - val maxMemoryInMB: Int = 128) extends Serializable + val maxMemoryInMB: Int = 128) extends Serializable { + + require(numClassesForClassification >= 2) + val isMulticlassClassification = numClassesForClassification > 2 + val isMulticlassWithCategoricalFeatures + = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 60f43e9278d2a..a0e2d91762782 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -31,23 +31,35 @@ object Entropy extends Impurity { /** * :: DeveloperApi :: - * entropy calculation - * @param c0 count of instances with label 0 - * @param c1 count of instances with label 1 - * @return entropy value + * information calculation for multiclass classification + * @param counts Array[Double] with counts for each label + * @param totalCount sum of counts for all labels + * @return information value */ @DeveloperApi - override def calculate(c0: Double, c1: Double): Double = { - if (c0 == 0 || c1 == 0) { - 0 - } else { - val total = c0 + c1 - val f0 = c0 / total - val f1 = c1 / total - -(f0 * log2(f0)) - (f1 * log2(f1)) + override def calculate(counts: Array[Double], totalCount: Double): Double = { + val numClasses = counts.length + var impurity = 0.0 + var classIndex = 0 + while (classIndex < numClasses) { + val classCount = counts(classIndex) + if (classCount != 0) { + val freq = classCount / totalCount + impurity -= freq * log2(freq) + } + classIndex += 1 } + impurity } + /** + * :: DeveloperApi :: + * variance calculation + * @param count number of instances + * @param sum sum of labels + * @param sumSquares summation of squares of the labels + */ + @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new UnsupportedOperationException("Entropy.calculate") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index c51d76d9b4c5b..48144b5e6d1e4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -30,23 +30,32 @@ object Gini extends Impurity { /** * :: DeveloperApi :: - * Gini coefficient calculation - * @param c0 count of instances with label 0 - * @param c1 count of instances with label 1 - * @return Gini coefficient value + * information calculation for multiclass classification + * @param counts Array[Double] with counts for each label + * @param totalCount sum of counts for all labels + * @return information value */ @DeveloperApi - override def calculate(c0: Double, c1: Double): Double = { - if (c0 == 0 || c1 == 0) { - 0 - } else { - val total = c0 + c1 - val f0 = c0 / total - val f1 = c1 / total - 1 - f0 * f0 - f1 * f1 + override def calculate(counts: Array[Double], totalCount: Double): Double = { + val numClasses = counts.length + var impurity = 1.0 + var classIndex = 0 + while (classIndex < numClasses) { + val freq = counts(classIndex) / totalCount + impurity -= freq * freq + classIndex += 1 } + impurity } + /** + * :: DeveloperApi :: + * variance calculation + * @param count number of instances + * @param sum sum of labels + * @param sumSquares summation of squares of the labels + */ + @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new UnsupportedOperationException("Gini.calculate") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 8eab247cf0932..7b2a9320cc21d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -28,13 +28,13 @@ trait Impurity extends Serializable { /** * :: DeveloperApi :: - * information calculation for binary classification - * @param c0 count of instances with label 0 - * @param c1 count of instances with label 1 + * information calculation for multiclass classification + * @param counts Array[Double] with counts for each label + * @param totalCount sum of counts for all labels * @return information value */ @DeveloperApi - def calculate(c0 : Double, c1 : Double): Double + def calculate(counts: Array[Double], totalCount: Double): Double /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 47d07122af30f..97149a99ead59 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -25,7 +25,16 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} */ @Experimental object Variance extends Impurity { - override def calculate(c0: Double, c1: Double): Double = + + /** + * :: DeveloperApi :: + * information calculation for multiclass classification + * @param counts Array[Double] with counts for each label + * @param totalCount sum of counts for all labels + * @return information value + */ + @DeveloperApi + override def calculate(counts: Array[Double], totalCount: Double): Double = throw new UnsupportedOperationException("Variance.calculate") /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala index 2d71e1e366069..c89c1e371a40e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._ * @param highSplit signifying the upper threshold for the continuous feature to be * accepted in the bin * @param featureType type of feature -- categorical or continuous - * @param category categorical label value accepted in the bin + * @param category categorical label value accepted in the bin for binary classification */ private[tree] case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index cc8a24cce9614..fb12298e0f5d3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -27,6 +27,7 @@ import org.apache.spark.annotation.DeveloperApi * @param leftImpurity left node impurity * @param rightImpurity right node impurity * @param predict predicted value + * @param prob probability of the label (classification only) */ @DeveloperApi class InformationGainStats( @@ -34,10 +35,11 @@ class InformationGainStats( val impurity: Double, val leftImpurity: Double, val rightImpurity: Double, - val predict: Double) extends Serializable { + val predict: Double, + val prob: Double = 0.0) extends Serializable { override def toString = { - "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f" - .format(gain, impurity, leftImpurity, rightImpurity, predict) + "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob = %f" + .format(gain, impurity, leftImpurity, rightImpurity, predict, prob) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 560a4ad71a4de..76a3bdf9b11c8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -61,6 +61,32 @@ class KMeansSuite extends FunSuite with LocalSparkContext { assert(model.clusterCenters.head === center) } + test("no distinct points") { + val data = sc.parallelize( + Array( + Vectors.dense(1.0, 2.0, 3.0), + Vectors.dense(1.0, 2.0, 3.0), + Vectors.dense(1.0, 2.0, 3.0)), + 2) + val center = Vectors.dense(1.0, 2.0, 3.0) + + // Make sure code runs. + var model = KMeans.train(data, k=2, maxIterations=1) + assert(model.clusterCenters.size === 2) + } + + test("more clusters than points") { + val data = sc.parallelize( + Array( + Vectors.dense(1.0, 2.0, 3.0), + Vectors.dense(1.0, 3.0, 4.0)), + 2) + + // Make sure code runs. + var model = KMeans.train(data, k=3, maxIterations=1) + assert(model.clusterCenters.size === 3) + } + test("single cluster with big dataset") { val smallData = Array( Vectors.dense(1.0, 2.0, 6.0), diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala new file mode 100644 index 0000000000000..1ea503971c864 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.Matrices +import org.apache.spark.mllib.util.LocalSparkContext + +class MulticlassMetricsSuite extends FunSuite with LocalSparkContext { + test("Multiclass evaluation metrics") { + /* + * Confusion matrix for 3-class classification with total 9 instances: + * |2|1|1| true class0 (4 instances) + * |1|3|0| true class1 (4 instances) + * |0|0|1| true class2 (1 instance) + */ + val confusionMatrix = Matrices.dense(3, 3, Array(2, 1, 0, 1, 3, 0, 1, 0, 1)) + val labels = Array(0.0, 1.0, 2.0) + val predictionAndLabels = sc.parallelize( + Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0), + (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2) + val metrics = new MulticlassMetrics(predictionAndLabels) + val delta = 0.0000001 + val fpRate0 = 1.0 / (9 - 4) + val fpRate1 = 1.0 / (9 - 4) + val fpRate2 = 1.0 / (9 - 1) + val precision0 = 2.0 / (2 + 1) + val precision1 = 3.0 / (3 + 1) + val precision2 = 1.0 / (1 + 1) + val recall0 = 2.0 / (2 + 2) + val recall1 = 3.0 / (3 + 1) + val recall2 = 1.0 / (1 + 0) + val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0) + val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1) + val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2) + val f2measure0 = (1 + 2 * 2) * precision0 * recall0 / (2 * 2 * precision0 + recall0) + val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1) + val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2) + + assert(metrics.confusionMatrix.toArray.sameElements(confusionMatrix.toArray)) + assert(math.abs(metrics.falsePositiveRate(0.0) - fpRate0) < delta) + assert(math.abs(metrics.falsePositiveRate(1.0) - fpRate1) < delta) + assert(math.abs(metrics.falsePositiveRate(2.0) - fpRate2) < delta) + assert(math.abs(metrics.precision(0.0) - precision0) < delta) + assert(math.abs(metrics.precision(1.0) - precision1) < delta) + assert(math.abs(metrics.precision(2.0) - precision2) < delta) + assert(math.abs(metrics.recall(0.0) - recall0) < delta) + assert(math.abs(metrics.recall(1.0) - recall1) < delta) + assert(math.abs(metrics.recall(2.0) - recall2) < delta) + assert(math.abs(metrics.fMeasure(0.0) - f1measure0) < delta) + assert(math.abs(metrics.fMeasure(1.0) - f1measure1) < delta) + assert(math.abs(metrics.fMeasure(2.0) - f1measure2) < delta) + assert(math.abs(metrics.fMeasure(0.0, 2.0) - f2measure0) < delta) + assert(math.abs(metrics.fMeasure(1.0, 2.0) - f2measure1) < delta) + assert(math.abs(metrics.fMeasure(2.0, 2.0) - f2measure2) < delta) + + assert(math.abs(metrics.recall - + (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1))) < delta) + assert(math.abs(metrics.recall - metrics.precision) < delta) + assert(math.abs(metrics.recall - metrics.fMeasure) < delta) + assert(math.abs(metrics.recall - metrics.weightedRecall) < delta) + assert(math.abs(metrics.weightedFalsePositiveRate - + ((4.0 / 9) * fpRate0 + (4.0 / 9) * fpRate1 + (1.0 / 9) * fpRate2)) < delta) + assert(math.abs(metrics.weightedPrecision - + ((4.0 / 9) * precision0 + (4.0 / 9) * precision1 + (1.0 / 9) * precision2)) < delta) + assert(math.abs(metrics.weightedRecall - + ((4.0 / 9) * recall0 + (4.0 / 9) * recall1 + (1.0 / 9) * recall2)) < delta) + assert(math.abs(metrics.weightedFMeasure - + ((4.0 / 9) * f1measure0 + (4.0 / 9) * f1measure1 + (1.0 / 9) * f1measure2)) < delta) + assert(math.abs(metrics.weightedFMeasure(2.0) - + ((4.0 / 9) * f2measure0 + (4.0 / 9) * f2measure1 + (1.0 / 9) * f2measure2)) < delta) + assert(metrics.labels.sameElements(labels)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index c9f9acf4c1335..a961f89456a18 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -96,37 +96,44 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext { test("svd of a full-rank matrix") { for (mat <- Seq(denseMat, sparseMat)) { - val localMat = mat.toBreeze() - val (localU, localSigma, localVt) = brzSvd(localMat) - val localV: BDM[Double] = localVt.t.toDenseMatrix - for (k <- 1 to n) { - val svd = mat.computeSVD(k, computeU = true) - val U = svd.U - val s = svd.s - val V = svd.V - assert(U.numRows() === m) - assert(U.numCols() === k) - assert(s.size === k) - assert(V.numRows === n) - assert(V.numCols === k) - assertColumnEqualUpToSign(U.toBreeze(), localU, k) - assertColumnEqualUpToSign(V.toBreeze.asInstanceOf[BDM[Double]], localV, k) - assert(closeToZero(s.toBreeze.asInstanceOf[BDV[Double]] - localSigma(0 until k))) + for (mode <- Seq("auto", "local-svd", "local-eigs", "dist-eigs")) { + val localMat = mat.toBreeze() + val (localU, localSigma, localVt) = brzSvd(localMat) + val localV: BDM[Double] = localVt.t.toDenseMatrix + for (k <- 1 to n) { + val skip = (mode == "local-eigs" || mode == "dist-eigs") && k == n + if (!skip) { + val svd = mat.computeSVD(k, computeU = true, 1e-9, 300, 1e-10, mode) + val U = svd.U + val s = svd.s + val V = svd.V + assert(U.numRows() === m) + assert(U.numCols() === k) + assert(s.size === k) + assert(V.numRows === n) + assert(V.numCols === k) + assertColumnEqualUpToSign(U.toBreeze(), localU, k) + assertColumnEqualUpToSign(V.toBreeze.asInstanceOf[BDM[Double]], localV, k) + assert(closeToZero(s.toBreeze.asInstanceOf[BDV[Double]] - localSigma(0 until k))) + } + } + val svdWithoutU = mat.computeSVD(1, computeU = false, 1e-9, 300, 1e-10, mode) + assert(svdWithoutU.U === null) } - val svdWithoutU = mat.computeSVD(n) - assert(svdWithoutU.U === null) } } test("svd of a low-rank matrix") { - val rows = sc.parallelize(Array.fill(4)(Vectors.dense(1.0, 1.0)), 2) - val mat = new RowMatrix(rows, 4, 2) - val svd = mat.computeSVD(2, computeU = true) - assert(svd.s.size === 1, "should not return zero singular values") - assert(svd.U.numRows() === 4) - assert(svd.U.numCols() === 1) - assert(svd.V.numRows === 2) - assert(svd.V.numCols === 1) + val rows = sc.parallelize(Array.fill(4)(Vectors.dense(1.0, 1.0, 1.0)), 2) + val mat = new RowMatrix(rows, 4, 3) + for (mode <- Seq("auto", "local-svd", "local-eigs", "dist-eigs")) { + val svd = mat.computeSVD(2, computeU = true, 1e-6, 300, 1e-10, mode) + assert(svd.s.size === 1, s"should not return zero singular values but got ${svd.s}") + assert(svd.U.numRows() === 4) + assert(svd.U.numCols() === 1) + assert(svd.V.numRows === 3) + assert(svd.V.numCols === 1) + } } def closeToZero(G: BDM[Double]): Boolean = { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala new file mode 100644 index 0000000000000..bce4251426df7 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat + +import org.scalatest.FunSuite + +import breeze.linalg.{DenseMatrix => BDM, Matrix => BM} + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation, + SpearmanCorrelation} +import org.apache.spark.mllib.util.LocalSparkContext + +class CorrelationSuite extends FunSuite with LocalSparkContext { + + // test input data + val xData = Array(1.0, 0.0, -2.0) + val yData = Array(4.0, 5.0, 3.0) + val data = Seq( + Vectors.dense(1.0, 0.0, 0.0, -2.0), + Vectors.dense(4.0, 5.0, 0.0, 3.0), + Vectors.dense(6.0, 7.0, 0.0, 8.0), + Vectors.dense(9.0, 0.0, 0.0, 1.0) + ) + + test("corr(x, y) default, pearson") { + val x = sc.parallelize(xData) + val y = sc.parallelize(yData) + val expected = 0.6546537 + val default = Statistics.corr(x, y) + val p1 = Statistics.corr(x, y, "pearson") + assert(approxEqual(expected, default)) + assert(approxEqual(expected, p1)) + } + + test("corr(x, y) spearman") { + val x = sc.parallelize(xData) + val y = sc.parallelize(yData) + val expected = 0.5 + val s1 = Statistics.corr(x, y, "spearman") + assert(approxEqual(expected, s1)) + } + + test("corr(X) default, pearson") { + val X = sc.parallelize(data) + val defaultMat = Statistics.corr(X) + val pearsonMat = Statistics.corr(X, "pearson") + val expected = BDM( + (1.00000000, 0.05564149, Double.NaN, 0.4004714), + (0.05564149, 1.00000000, Double.NaN, 0.9135959), + (Double.NaN, Double.NaN, 1.00000000, Double.NaN), + (0.40047142, 0.91359586, Double.NaN,1.0000000)) + assert(matrixApproxEqual(defaultMat.toBreeze, expected)) + assert(matrixApproxEqual(pearsonMat.toBreeze, expected)) + } + + test("corr(X) spearman") { + val X = sc.parallelize(data) + val spearmanMat = Statistics.corr(X, "spearman") + val expected = BDM( + (1.0000000, 0.1054093, Double.NaN, 0.4000000), + (0.1054093, 1.0000000, Double.NaN, 0.9486833), + (Double.NaN, Double.NaN, 1.00000000, Double.NaN), + (0.4000000, 0.9486833, Double.NaN, 1.0000000)) + assert(matrixApproxEqual(spearmanMat.toBreeze, expected)) + } + + test("method identification") { + val pearson = PearsonCorrelation + val spearman = SpearmanCorrelation + + assert(Correlations.getCorrelationFromName("pearson") === pearson) + assert(Correlations.getCorrelationFromName("spearman") === spearman) + + // Should throw IllegalArgumentException + try { + Correlations.getCorrelationFromName("kendall") + assert(false) + } catch { + case ie: IllegalArgumentException => + } + } + + def approxEqual(v1: Double, v2: Double, threshold: Double = 1e-6): Boolean = { + if (v1.isNaN) { + v2.isNaN + } else { + math.abs(v1 - v2) <= threshold + } + } + + def matrixApproxEqual(A: BM[Double], B: BM[Double], threshold: Double = 1e-6): Boolean = { + for (i <- 0 until A.rows; j <- 0 until A.cols) { + if (!approxEqual(A(i, j), B(i, j), threshold)) { + println("i, j = " + i + ", " + j + " actual: " + A(i, j) + " expected:" + B(i, j)) + return false + } + } + true + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala new file mode 100644 index 0000000000000..4b7b019d820b4 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.TestingUtils._ + +class MultivariateOnlineSummarizerSuite extends FunSuite { + + test("basic error handing") { + val summarizer = new MultivariateOnlineSummarizer + + assert(summarizer.count === 0, "should be zero since nothing is added.") + + withClue("Getting numNonzeros from empty summarizer should throw exception.") { + intercept[IllegalArgumentException] { + summarizer.numNonzeros + } + } + + withClue("Getting variance from empty summarizer should throw exception.") { + intercept[IllegalArgumentException] { + summarizer.variance + } + } + + withClue("Getting mean from empty summarizer should throw exception.") { + intercept[IllegalArgumentException] { + summarizer.mean + } + } + + withClue("Getting max from empty summarizer should throw exception.") { + intercept[IllegalArgumentException] { + summarizer.max + } + } + + withClue("Getting min from empty summarizer should throw exception.") { + intercept[IllegalArgumentException] { + summarizer.min + } + } + + summarizer.add(Vectors.dense(-1.0, 2.0, 6.0)).add(Vectors.sparse(3, Seq((0, -2.0), (1, 6.0)))) + + withClue("Adding a new dense sample with different array size should throw exception.") { + intercept[IllegalArgumentException] { + summarizer.add(Vectors.dense(3.0, 1.0)) + } + } + + withClue("Adding a new sparse sample with different array size should throw exception.") { + intercept[IllegalArgumentException] { + summarizer.add(Vectors.sparse(5, Seq((0, -2.0), (1, 6.0)))) + } + } + + val summarizer2 = (new MultivariateOnlineSummarizer).add(Vectors.dense(1.0, -2.0, 0.0, 4.0)) + withClue("Merging a new summarizer with different dimensions should throw exception.") { + intercept[IllegalArgumentException] { + summarizer.merge(summarizer2) + } + } + } + + test("dense vector input") { + // For column 2, the maximum will be 0.0, and it's not explicitly added since we ignore all + // the zeros; it's a case we need to test. For column 3, the minimum will be 0.0 which we + // need to test as well. + val summarizer = (new MultivariateOnlineSummarizer) + .add(Vectors.dense(-1.0, 0.0, 6.0)) + .add(Vectors.dense(3.0, -3.0, 0.0)) + + assert(summarizer.mean.almostEquals(Vectors.dense(1.0, -1.5, 3.0)), "mean mismatch") + + assert(summarizer.min.almostEquals(Vectors.dense(-1.0, -3, 0.0)), "min mismatch") + + assert(summarizer.max.almostEquals(Vectors.dense(3.0, 0.0, 6.0)), "max mismatch") + + assert(summarizer.numNonzeros.almostEquals(Vectors.dense(2, 1, 1)), "numNonzeros mismatch") + + assert(summarizer.variance.almostEquals(Vectors.dense(8.0, 4.5, 18.0)), "variance mismatch") + + assert(summarizer.count === 2) + } + + test("sparse vector input") { + val summarizer = (new MultivariateOnlineSummarizer) + .add(Vectors.sparse(3, Seq((0, -1.0), (2, 6.0)))) + .add(Vectors.sparse(3, Seq((0, 3.0), (1, -3.0)))) + + assert(summarizer.mean.almostEquals(Vectors.dense(1.0, -1.5, 3.0)), "mean mismatch") + + assert(summarizer.min.almostEquals(Vectors.dense(-1.0, -3, 0.0)), "min mismatch") + + assert(summarizer.max.almostEquals(Vectors.dense(3.0, 0.0, 6.0)), "max mismatch") + + assert(summarizer.numNonzeros.almostEquals(Vectors.dense(2, 1, 1)), "numNonzeros mismatch") + + assert(summarizer.variance.almostEquals(Vectors.dense(8.0, 4.5, 18.0)), "variance mismatch") + + assert(summarizer.count === 2) + } + + test("mixing dense and sparse vector input") { + val summarizer = (new MultivariateOnlineSummarizer) + .add(Vectors.sparse(3, Seq((0, -2.0), (1, 2.3)))) + .add(Vectors.dense(0.0, -1.0, -3.0)) + .add(Vectors.sparse(3, Seq((1, -5.1)))) + .add(Vectors.dense(3.8, 0.0, 1.9)) + .add(Vectors.dense(1.7, -0.6, 0.0)) + .add(Vectors.sparse(3, Seq((1, 1.9), (2, 0.0)))) + + assert(summarizer.mean.almostEquals( + Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333)), "mean mismatch") + + assert(summarizer.min.almostEquals(Vectors.dense(-2.0, -5.1, -3)), "min mismatch") + + assert(summarizer.max.almostEquals(Vectors.dense(3.8, 2.3, 1.9)), "max mismatch") + + assert(summarizer.numNonzeros.almostEquals(Vectors.dense(3, 5, 2)), "numNonzeros mismatch") + + assert(summarizer.variance.almostEquals( + Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666)), "variance mismatch") + + assert(summarizer.count === 6) + } + + test("merging two summarizers") { + val summarizer1 = (new MultivariateOnlineSummarizer) + .add(Vectors.sparse(3, Seq((0, -2.0), (1, 2.3)))) + .add(Vectors.dense(0.0, -1.0, -3.0)) + + val summarizer2 = (new MultivariateOnlineSummarizer) + .add(Vectors.sparse(3, Seq((1, -5.1)))) + .add(Vectors.dense(3.8, 0.0, 1.9)) + .add(Vectors.dense(1.7, -0.6, 0.0)) + .add(Vectors.sparse(3, Seq((1, 1.9), (2, 0.0)))) + + val summarizer = summarizer1.merge(summarizer2) + + assert(summarizer.mean.almostEquals( + Vectors.dense(0.583333333333, -0.416666666666, -0.183333333333)), "mean mismatch") + + assert(summarizer.min.almostEquals(Vectors.dense(-2.0, -5.1, -3)), "min mismatch") + + assert(summarizer.max.almostEquals(Vectors.dense(3.8, 2.3, 1.9)), "max mismatch") + + assert(summarizer.numNonzeros.almostEquals(Vectors.dense(3, 5, 2)), "numNonzeros mismatch") + + assert(summarizer.variance.almostEquals( + Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666)), "variance mismatch") + + assert(summarizer.count === 6) + } + + test("merging summarizer with empty summarizer") { + // If one of two is non-empty, this should return the non-empty summarizer. + // If both of them are empty, then just return the empty summarizer. + val summarizer1 = (new MultivariateOnlineSummarizer) + .add(Vectors.dense(0.0, -1.0, -3.0)).merge(new MultivariateOnlineSummarizer) + assert(summarizer1.count === 1) + + val summarizer2 = (new MultivariateOnlineSummarizer) + .merge((new MultivariateOnlineSummarizer).add(Vectors.dense(0.0, -1.0, -3.0))) + assert(summarizer2.count === 1) + + val summarizer3 = (new MultivariateOnlineSummarizer).merge(new MultivariateOnlineSummarizer) + assert(summarizer3.count === 0) + + assert(summarizer1.mean.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "mean mismatch") + + assert(summarizer2.mean.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "mean mismatch") + + assert(summarizer1.min.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "min mismatch") + + assert(summarizer2.min.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "min mismatch") + + assert(summarizer1.max.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "max mismatch") + + assert(summarizer2.max.almostEquals(Vectors.dense(0.0, -1.0, -3.0)), "max mismatch") + + assert(summarizer1.numNonzeros.almostEquals(Vectors.dense(0, 1, 1)), "numNonzeros mismatch") + + assert(summarizer2.numNonzeros.almostEquals(Vectors.dense(0, 1, 1)), "numNonzeros mismatch") + + assert(summarizer1.variance.almostEquals(Vectors.dense(0, 0, 0)), "variance mismatch") + + assert(summarizer2.variance.almostEquals(Vectors.dense(0, 0, 0)), "variance mismatch") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 35e92d71dc63f..5961a618c59d9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.mllib.tree import org.scalatest.FunSuite -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model.Filter import org.apache.spark.mllib.tree.model.Split @@ -28,6 +27,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.regression.LabeledPoint class DecisionTreeSuite extends FunSuite with LocalSparkContext { @@ -35,7 +35,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 100) + val strategy = new Strategy(Classification, Gini, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(bins.length === 2) @@ -51,6 +51,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Classification, Gini, maxDepth = 3, + numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) @@ -130,8 +131,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Classification, Gini, maxDepth = 3, + numClassesForClassification = 2, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) // Check splits. @@ -231,6 +233,162 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(1)(3) === null) } + test("extract categories from a number for multiclass classification") { + val l = DecisionTree.extractMultiClassCategories(13, 10) + assert(l.length === 3) + assert(List(3.0, 2.0, 0.0).toSeq == l.toSeq) + } + + test("split and bin calculations for unordered categorical variables with multiclass " + + "classification") { + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + numClassesForClassification = 100, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + + // Expecting 2^2 - 1 = 3 bins/splits + assert(splits(0)(0).feature === 0) + assert(splits(0)(0).threshold === Double.MinValue) + assert(splits(0)(0).featureType === Categorical) + assert(splits(0)(0).categories.length === 1) + assert(splits(0)(0).categories.contains(0.0)) + assert(splits(1)(0).feature === 1) + assert(splits(1)(0).threshold === Double.MinValue) + assert(splits(1)(0).featureType === Categorical) + assert(splits(1)(0).categories.length === 1) + assert(splits(1)(0).categories.contains(0.0)) + + assert(splits(0)(1).feature === 0) + assert(splits(0)(1).threshold === Double.MinValue) + assert(splits(0)(1).featureType === Categorical) + assert(splits(0)(1).categories.length === 1) + assert(splits(0)(1).categories.contains(1.0)) + assert(splits(1)(1).feature === 1) + assert(splits(1)(1).threshold === Double.MinValue) + assert(splits(1)(1).featureType === Categorical) + assert(splits(1)(1).categories.length === 1) + assert(splits(1)(1).categories.contains(1.0)) + + assert(splits(0)(2).feature === 0) + assert(splits(0)(2).threshold === Double.MinValue) + assert(splits(0)(2).featureType === Categorical) + assert(splits(0)(2).categories.length === 2) + assert(splits(0)(2).categories.contains(0.0)) + assert(splits(0)(2).categories.contains(1.0)) + assert(splits(1)(2).feature === 1) + assert(splits(1)(2).threshold === Double.MinValue) + assert(splits(1)(2).featureType === Categorical) + assert(splits(1)(2).categories.length === 2) + assert(splits(1)(2).categories.contains(0.0)) + assert(splits(1)(2).categories.contains(1.0)) + + assert(splits(0)(3) === null) + assert(splits(1)(3) === null) + + + // Check bins. + + assert(bins(0)(0).category === Double.MinValue) + assert(bins(0)(0).lowSplit.categories.length === 0) + assert(bins(0)(0).highSplit.categories.length === 1) + assert(bins(0)(0).highSplit.categories.contains(0.0)) + assert(bins(1)(0).category === Double.MinValue) + assert(bins(1)(0).lowSplit.categories.length === 0) + assert(bins(1)(0).highSplit.categories.length === 1) + assert(bins(1)(0).highSplit.categories.contains(0.0)) + + assert(bins(0)(1).category === Double.MinValue) + assert(bins(0)(1).lowSplit.categories.length === 1) + assert(bins(0)(1).lowSplit.categories.contains(0.0)) + assert(bins(0)(1).highSplit.categories.length === 1) + assert(bins(0)(1).highSplit.categories.contains(1.0)) + assert(bins(1)(1).category === Double.MinValue) + assert(bins(1)(1).lowSplit.categories.length === 1) + assert(bins(1)(1).lowSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.length === 1) + assert(bins(1)(1).highSplit.categories.contains(1.0)) + + assert(bins(0)(2).category === Double.MinValue) + assert(bins(0)(2).lowSplit.categories.length === 1) + assert(bins(0)(2).lowSplit.categories.contains(1.0)) + assert(bins(0)(2).highSplit.categories.length === 2) + assert(bins(0)(2).highSplit.categories.contains(1.0)) + assert(bins(0)(2).highSplit.categories.contains(0.0)) + assert(bins(1)(2).category === Double.MinValue) + assert(bins(1)(2).lowSplit.categories.length === 1) + assert(bins(1)(2).lowSplit.categories.contains(1.0)) + assert(bins(1)(2).highSplit.categories.length === 2) + assert(bins(1)(2).highSplit.categories.contains(1.0)) + assert(bins(1)(2).highSplit.categories.contains(0.0)) + + assert(bins(0)(3) === null) + assert(bins(1)(3) === null) + + } + + test("split and bin calculations for ordered categorical variables with multiclass " + + "classification") { + val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() + assert(arr.length === 3000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + numClassesForClassification = 100, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + + // 2^10 - 1 > 100, so categorical variables will be ordered + + assert(splits(0)(0).feature === 0) + assert(splits(0)(0).threshold === Double.MinValue) + assert(splits(0)(0).featureType === Categorical) + assert(splits(0)(0).categories.length === 1) + assert(splits(0)(0).categories.contains(1.0)) + + assert(splits(0)(1).feature === 0) + assert(splits(0)(1).threshold === Double.MinValue) + assert(splits(0)(1).featureType === Categorical) + assert(splits(0)(1).categories.length === 2) + assert(splits(0)(1).categories.contains(2.0)) + + assert(splits(0)(2).feature === 0) + assert(splits(0)(2).threshold === Double.MinValue) + assert(splits(0)(2).featureType === Categorical) + assert(splits(0)(2).categories.length === 3) + assert(splits(0)(2).categories.contains(2.0)) + assert(splits(0)(2).categories.contains(1.0)) + + assert(splits(0)(10) === null) + assert(splits(1)(10) === null) + + + // Check bins. + + assert(bins(0)(0).category === 1.0) + assert(bins(0)(0).lowSplit.categories.length === 0) + assert(bins(0)(0).highSplit.categories.length === 1) + assert(bins(0)(0).highSplit.categories.contains(1.0)) + assert(bins(0)(1).category === 2.0) + assert(bins(0)(1).lowSplit.categories.length === 1) + assert(bins(0)(1).highSplit.categories.length === 2) + assert(bins(0)(1).highSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.contains(2.0)) + + assert(bins(0)(10) === null) + + } + + test("classification stump with all categorical variables") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) @@ -238,6 +396,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, + numClassesForClassification = 2, maxDepth = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) @@ -253,8 +412,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = bestSplits(0)._2 assert(stats.gain > 0) - assert(stats.predict > 0.4) - assert(stats.predict < 0.5) + assert(stats.predict === 1) + assert(stats.prob == 0.6) assert(stats.impurity > 0.2) } @@ -280,8 +439,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = bestSplits(0)._2 assert(stats.gain > 0) - assert(stats.predict > 0.4) - assert(stats.predict < 0.5) + assert(stats.predict == 0.6) assert(stats.impurity > 0.2) } @@ -289,7 +447,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 100) + val strategy = new Strategy(Classification, Gini, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) @@ -312,7 +470,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 100) + val strategy = new Strategy(Classification, Gini, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) @@ -336,7 +494,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 100) + val strategy = new Strategy(Classification, Entropy, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) @@ -360,7 +518,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 100) + val strategy = new Strategy(Classification, Entropy, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) @@ -380,11 +538,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(0)._2.predict === 1) } - test("test second level node building with/without groups") { + test("second level node building with/without groups") { val arr = DecisionTreeSuite.generateOrderedLabeledPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 100) + val strategy = new Strategy(Classification, Entropy, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) @@ -426,6 +584,82 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } + test("stump with categorical variables for multiclass classification") { + val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + assert(strategy.isMulticlassClassification) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + Array[List[Filter]](), splits, bins, 10) + + assert(bestSplits.length === 1) + val bestSplit = bestSplits(0)._1 + assert(bestSplit.feature === 0) + assert(bestSplit.categories.length === 1) + assert(bestSplit.categories.contains(1)) + assert(bestSplit.featureType === Categorical) + } + + test("stump with continuous variables for multiclass classification") { + val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 3) + assert(strategy.isMulticlassClassification) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + Array[List[Filter]](), splits, bins, 10) + + assert(bestSplits.length === 1) + val bestSplit = bestSplits(0)._1 + + assert(bestSplit.feature === 1) + assert(bestSplit.featureType === Continuous) + assert(bestSplit.threshold > 1980) + assert(bestSplit.threshold < 2020) + + } + + test("stump with continuous + categorical variables for multiclass classification") { + val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3)) + assert(strategy.isMulticlassClassification) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + Array[List[Filter]](), splits, bins, 10) + + assert(bestSplits.length === 1) + val bestSplit = bestSplits(0)._1 + + assert(bestSplit.feature === 1) + assert(bestSplit.featureType === Continuous) + assert(bestSplit.threshold > 1980) + assert(bestSplit.threshold < 2020) + } + + test("stump with categorical variables for ordered multiclass classification") { + val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) + assert(strategy.isMulticlassClassification) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + Array[List[Filter]](), splits, bins, 10) + + assert(bestSplits.length === 1) + val bestSplit = bestSplits(0)._1 + assert(bestSplit.feature === 0) + assert(bestSplit.categories.length === 1) + assert(bestSplit.categories.contains(1.0)) + assert(bestSplit.featureType === Categorical) + } + + } object DecisionTreeSuite { @@ -473,4 +707,47 @@ object DecisionTreeSuite { } arr } + + def generateCategoricalDataPointsForMulticlass(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](3000) + for (i <- 0 until 3000) { + if (i < 1000) { + arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0)) + } else if (i < 2000) { + arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0)) + } else { + arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0)) + } + } + arr + } + + def generateContinuousDataPointsForMulticlass(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](3000) + for (i <- 0 until 3000) { + if (i < 2000) { + arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, i)) + } else { + arr(i) = new LabeledPoint(1.0, Vectors.dense(2.0, i)) + } + } + arr + } + + def generateCategoricalDataPointsForMulticlassForOrderedFeatures(): + Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](3000) + for (i <- 0 until 3000) { + if (i < 1000) { + arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0)) + } else if (i < 2000) { + arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0)) + } else { + arr(i) = new LabeledPoint(1.0, Vectors.dense(2.0, 2.0)) + } + } + arr + } + + } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala similarity index 50% rename from core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala rename to mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala index c4a8996c0b9a9..64b1ba7527183 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala @@ -15,22 +15,31 @@ * limitations under the License. */ -package org.apache.spark.ui.jobs +package org.apache.spark.mllib.util -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.Vector -/** - * :: DeveloperApi :: - * Class for reporting aggregated metrics for each executor in stage UI. - */ -@DeveloperApi -class ExecutorSummary { - var taskTime : Long = 0 - var failedTasks : Int = 0 - var succeededTasks : Int = 0 - var inputBytes: Long = 0 - var shuffleRead : Long = 0 - var shuffleWrite : Long = 0 - var memoryBytesSpilled : Long = 0 - var diskBytesSpilled : Long = 0 +object TestingUtils { + + implicit class DoubleWithAlmostEquals(val x: Double) { + // An improved version of AlmostEquals would always divide by the larger number. + // This will avoid the problem of diving by zero. + def almostEquals(y: Double, epsilon: Double = 1E-10): Boolean = { + if(x == y) { + true + } else if(math.abs(x) > math.abs(y)) { + math.abs(x - y) / math.abs(x) < epsilon + } else { + math.abs(x - y) / math.abs(y) < epsilon + } + } + } + + implicit class VectorWithAlmostEquals(val x: Vector) { + def almostEquals(y: Vector, epsilon: Double = 1E-10): Boolean = { + x.toArray.corresponds(y.toArray) { + _.almostEquals(_, epsilon) + } + } + } } diff --git a/pom.xml b/pom.xml index 05f76d566e9d1..4e2d64a833640 100644 --- a/pom.xml +++ b/pom.xml @@ -110,7 +110,7 @@ UTF-8 1.6 - + spark 2.10.4 2.10 0.18.1 @@ -297,6 +297,11 @@ snappy-java 1.0.5 + + net.jpountz.lz4 + lz4 + 1.2.0 + com.clearspring.analytics stream @@ -535,6 +540,10 @@ org.mortbay.jetty servlet-api-2.5 + + javax.servlet + servlet-api + junit junit @@ -605,7 +614,6 @@ net.java.dev.jets3t jets3t ${jets3t.version} - runtime commons-logging @@ -618,6 +626,10 @@ hadoop-yarn-api ${yarn.version} + + javax.servlet + servlet-api + asm asm diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index bb2d73741c3bf..034ba6a7bf50f 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -15,13 +15,16 @@ * limitations under the License. */ +import sbt._ +import sbt.Keys.version + import com.typesafe.tools.mima.core._ import com.typesafe.tools.mima.core.MissingClassProblem import com.typesafe.tools.mima.core.MissingTypesProblem import com.typesafe.tools.mima.core.ProblemFilters._ import com.typesafe.tools.mima.plugin.MimaKeys.{binaryIssueFilters, previousArtifact} import com.typesafe.tools.mima.plugin.MimaPlugin.mimaDefaultSettings -import sbt._ + object MimaBuild { @@ -53,7 +56,7 @@ object MimaBuild { excludePackage("org.apache.spark." + packageName) } - def ignoredABIProblems(base: File) = { + def ignoredABIProblems(base: File, currentSparkVersion: String) = { // Excludes placed here will be used for all Spark versions val defaultExcludes = Seq() @@ -77,11 +80,16 @@ object MimaBuild { } defaultExcludes ++ ignoredClasses.flatMap(excludeClass) ++ - ignoredMembers.flatMap(excludeMember) ++ MimaExcludes.excludes + ignoredMembers.flatMap(excludeMember) ++ MimaExcludes.excludes(currentSparkVersion) + } + + def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { + val organization = "org.apache.spark" + val previousSparkVersion = "1.0.0" + val fullId = "spark-" + projectRef.project + "_2.10" + mimaDefaultSettings ++ + Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), + binaryIssueFilters ++= ignoredABIProblems(sparkHome, version.value)) } - def mimaSettings(sparkHome: File) = mimaDefaultSettings ++ Seq( - previousArtifact := None, - binaryIssueFilters ++= ignoredABIProblems(sparkHome) - ) } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 1621833e124f5..e0f433b26f7ff 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -31,72 +31,99 @@ import com.typesafe.tools.mima.core._ * MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") */ object MimaExcludes { - val excludes = - SparkBuild.SPARK_VERSION match { - case v if v.startsWith("1.1") => - Seq( - MimaBuild.excludeSparkPackage("deploy"), - MimaBuild.excludeSparkPackage("graphx") - ) ++ - Seq( - // Adding new method to JavaRDLike trait - we should probably mark this as a developer API. - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitions"), - // We made a mistake earlier (ed06500d3) in the Java API to use default parameter values - // for countApproxDistinct* functions, which does not work in Java. We later removed - // them, and use the following to tell Mima to not care about them. - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDD.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.storage.MemoryStore.Entry"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$" - + "createZero$1") - ) ++ - Seq( // Ignore some private methods in ALS. - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), - ProblemFilters.exclude[MissingMethodProblem]( // The only public constructor is the one without arguments. - "org.apache.spark.mllib.recommendation.ALS.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$$default$7") - ) ++ - MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++ - MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++ - MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++ - MimaBuild.excludeSparkClass("storage.Values") ++ - MimaBuild.excludeSparkClass("storage.Entry") ++ - MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") - case v if v.startsWith("1.0") => - Seq( - MimaBuild.excludeSparkPackage("api.java"), - MimaBuild.excludeSparkPackage("mllib"), - MimaBuild.excludeSparkPackage("streaming") - ) ++ - MimaBuild.excludeSparkClass("rdd.ClassTags") ++ - MimaBuild.excludeSparkClass("util.XORShiftRandom") ++ - MimaBuild.excludeSparkClass("graphx.EdgeRDD") ++ - MimaBuild.excludeSparkClass("graphx.VertexRDD") ++ - MimaBuild.excludeSparkClass("graphx.impl.GraphImpl") ++ - MimaBuild.excludeSparkClass("graphx.impl.RoutingTable") ++ - MimaBuild.excludeSparkClass("graphx.util.collection.PrimitiveKeyOpenHashMap") ++ - MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") ++ - MimaBuild.excludeSparkClass("mllib.recommendation.MFDataGenerator") ++ - MimaBuild.excludeSparkClass("mllib.optimization.SquaredGradient") ++ - MimaBuild.excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++ - MimaBuild.excludeSparkClass("mllib.regression.LassoWithSGD") ++ - MimaBuild.excludeSparkClass("mllib.regression.LinearRegressionWithSGD") - case _ => Seq() - } + + def excludes(version: String) = version match { + case v if v.startsWith("1.1") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("graphx") + ) ++ + closures.map(method => ProblemFilters.exclude[MissingMethodProblem](method)) ++ + Seq( + // Adding new method to JavaRDLike trait - we should probably mark this as a developer API. + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitions"), + // We made a mistake earlier (ed06500d3) in the Java API to use default parameter values + // for countApproxDistinct* functions, which does not work in Java. We later removed + // them, and use the following to tell Mima to not care about them. + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaPairRDD.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDD.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.storage.MemoryStore.Entry"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$" + + "createZero$1") + ) ++ + Seq( + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.FlumeReceiver.this") + ) ++ + Seq( // Ignore some private methods in ALS. + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), + ProblemFilters.exclude[MissingMethodProblem]( // The only public constructor is the one without arguments. + "org.apache.spark.mllib.recommendation.ALS.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$$default$7") + ) ++ + MimaBuild.excludeSparkClass("mllib.linalg.distributed.ColumnStatisticsAggregator") ++ + MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++ + MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++ + MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++ + MimaBuild.excludeSparkClass("storage.Values") ++ + MimaBuild.excludeSparkClass("storage.Entry") ++ + MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++ + Seq( + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Gini.calculate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Entropy.calculate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Variance.calculate") + ) + case v if v.startsWith("1.0") => + Seq( + MimaBuild.excludeSparkPackage("api.java"), + MimaBuild.excludeSparkPackage("mllib"), + MimaBuild.excludeSparkPackage("streaming") + ) ++ + MimaBuild.excludeSparkClass("rdd.ClassTags") ++ + MimaBuild.excludeSparkClass("util.XORShiftRandom") ++ + MimaBuild.excludeSparkClass("graphx.EdgeRDD") ++ + MimaBuild.excludeSparkClass("graphx.VertexRDD") ++ + MimaBuild.excludeSparkClass("graphx.impl.GraphImpl") ++ + MimaBuild.excludeSparkClass("graphx.impl.RoutingTable") ++ + MimaBuild.excludeSparkClass("graphx.util.collection.PrimitiveKeyOpenHashMap") ++ + MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") ++ + MimaBuild.excludeSparkClass("mllib.recommendation.MFDataGenerator") ++ + MimaBuild.excludeSparkClass("mllib.optimization.SquaredGradient") ++ + MimaBuild.excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++ + MimaBuild.excludeSparkClass("mllib.regression.LassoWithSGD") ++ + MimaBuild.excludeSparkClass("mllib.regression.LinearRegressionWithSGD") + case _ => Seq() + } + + private val closures = Seq( + "org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$mergeMaps$1", + "org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$countPartition$1", + "org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$distributePartition$1", + "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$mergeValue$1", + "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$writeToFile$1", + "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$reducePartition$1", + "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$writeShard$1", + "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$mergeCombiners$1", + "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$process$1", + "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$createCombiner$1", + "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$mergeMaps$1" + ) } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 599714233c18f..5461d25d72d7e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -15,498 +15,177 @@ * limitations under the License. */ +import scala.util.Properties +import scala.collection.JavaConversions._ + import sbt._ import sbt.Classpaths.publishTask import sbt.Keys._ -import sbtassembly.Plugin._ -import AssemblyKeys._ -import scala.util.Properties import org.scalastyle.sbt.ScalastylePlugin.{Settings => ScalaStyleSettings} -import com.typesafe.tools.mima.plugin.MimaKeys.previousArtifact -import sbtunidoc.Plugin._ -import UnidocKeys._ - -import scala.collection.JavaConversions._ - -// For Sonatype publishing -// import com.jsuereth.pgp.sbtplugin.PgpKeys._ - -object SparkBuild extends Build { - val SPARK_VERSION = "1.1.0-SNAPSHOT" - val SPARK_VERSION_SHORT = SPARK_VERSION.replaceAll("-SNAPSHOT", "") - - // Hadoop version to build against. For example, "1.0.4" for Apache releases, or - // "2.0.0-mr1-cdh4.2.0" for Cloudera Hadoop. Note that these variables can be set - // through the environment variables SPARK_HADOOP_VERSION and SPARK_YARN. - val DEFAULT_HADOOP_VERSION = "1.0.4" - - // Whether the Hadoop version to build against is 2.2.x, or a variant of it. This can be set - // through the SPARK_IS_NEW_HADOOP environment variable. - val DEFAULT_IS_NEW_HADOOP = false - - val DEFAULT_YARN = false - - val DEFAULT_HIVE = false - - // HBase version; set as appropriate. - val HBASE_VERSION = "0.94.6" - - // Target JVM version - val SCALAC_JVM_VERSION = "jvm-1.6" - val JAVAC_JVM_VERSION = "1.6" - - lazy val root = Project("root", file("."), settings = rootSettings) aggregate(allProjects: _*) +import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys} +import net.virtualvoid.sbt.graph.Plugin.graphSettings - lazy val core = Project("core", file("core"), settings = coreSettings) +object BuildCommons { - /** Following project only exists to pull previous artifacts of Spark for generating - Mima ignores. For more information see: SPARK 2071 */ - lazy val oldDeps = Project("oldDeps", file("dev"), settings = oldDepsSettings) - - def replDependencies = Seq[ProjectReference](core, graphx, bagel, mllib, sql) ++ maybeHiveRef - - lazy val repl = Project("repl", file("repl"), settings = replSettings) - .dependsOn(replDependencies.map(a => a: sbt.ClasspathDep[sbt.ProjectReference]): _*) - - lazy val tools = Project("tools", file("tools"), settings = toolsSettings) dependsOn(core) dependsOn(streaming) - - lazy val bagel = Project("bagel", file("bagel"), settings = bagelSettings) dependsOn(core) - - lazy val graphx = Project("graphx", file("graphx"), settings = graphxSettings) dependsOn(core) + private val buildLocation = file(".").getAbsoluteFile.getParentFile - lazy val catalyst = Project("catalyst", file("sql/catalyst"), settings = catalystSettings) dependsOn(core) + val allProjects@Seq(bagel, catalyst, core, graphx, hive, mllib, repl, spark, sql, streaming, + streamingFlume, streamingKafka, streamingMqtt, streamingTwitter, streamingZeromq) = + Seq("bagel", "catalyst", "core", "graphx", "hive", "mllib", "repl", "spark", "sql", + "streaming", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", + "streaming-zeromq").map(ProjectRef(buildLocation, _)) - lazy val sql = Project("sql", file("sql/core"), settings = sqlCoreSettings) dependsOn(core) dependsOn(catalyst % "compile->compile;test->test") + val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl) = + Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl") + .map(ProjectRef(buildLocation, _)) - lazy val hive = Project("hive", file("sql/hive"), settings = hiveSettings) dependsOn(sql) + val assemblyProjects@Seq(assembly, examples) = Seq("assembly", "examples") + .map(ProjectRef(buildLocation, _)) - lazy val maybeHive: Seq[ClasspathDependency] = if (isHiveEnabled) Seq(hive) else Seq() - lazy val maybeHiveRef: Seq[ProjectReference] = if (isHiveEnabled) Seq(hive) else Seq() + val tools = "tools" - lazy val streaming = Project("streaming", file("streaming"), settings = streamingSettings) dependsOn(core) + val sparkHome = buildLocation +} - lazy val mllib = Project("mllib", file("mllib"), settings = mllibSettings) dependsOn(core) +object SparkBuild extends PomBuild { - lazy val assemblyProj = Project("assembly", file("assembly"), settings = assemblyProjSettings) - .dependsOn(core, graphx, bagel, mllib, streaming, repl, sql) dependsOn(maybeYarn: _*) dependsOn(maybeHive: _*) dependsOn(maybeGanglia: _*) + import BuildCommons._ + import scala.collection.mutable.Map - lazy val assembleDepsTask = TaskKey[Unit]("assemble-deps") - lazy val assembleDeps = assembleDepsTask := { - println() - println("**** NOTE ****") - println("'sbt/sbt assemble-deps' is no longer supported.") - println("Instead create a normal assembly and:") - println(" export SPARK_PREPEND_CLASSES=1 (toggle on)") - println(" unset SPARK_PREPEND_CLASSES (toggle off)") - println() - } + val projectsMap: Map[String, Seq[Setting[_]]] = Map.empty - // A configuration to set an alternative publishLocalConfiguration - lazy val MavenCompile = config("m2r") extend(Compile) - lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy") - val sparkHome = System.getProperty("user.dir") - - // Allows build configuration to be set through environment variables - lazy val hadoopVersion = Properties.envOrElse("SPARK_HADOOP_VERSION", DEFAULT_HADOOP_VERSION) - lazy val isNewHadoop = Properties.envOrNone("SPARK_IS_NEW_HADOOP") match { - case None => { - val isNewHadoopVersion = "^2\\.[2-9]+".r.findFirstIn(hadoopVersion).isDefined - (isNewHadoopVersion|| DEFAULT_IS_NEW_HADOOP) + // Provides compatibility for older versions of the Spark build + def backwardCompatibility = { + import scala.collection.mutable + var isAlphaYarn = false + var profiles: mutable.Seq[String] = mutable.Seq.empty + if (Properties.envOrNone("SPARK_GANGLIA_LGPL").isDefined) { + println("NOTE: SPARK_GANGLIA_LGPL is deprecated, please use -Pganglia-lgpl flag.") + profiles ++= Seq("spark-ganglia-lgpl") + } + if (Properties.envOrNone("SPARK_HIVE").isDefined) { + println("NOTE: SPARK_HIVE is deprecated, please use -Phive flag.") + profiles ++= Seq("hive") + } + Properties.envOrNone("SPARK_HADOOP_VERSION") match { + case Some(v) => + if (v.matches("0.23.*")) isAlphaYarn = true + println("NOTE: SPARK_HADOOP_VERSION is deprecated, please use -Dhadoop.version=" + v) + System.setProperty("hadoop.version", v) + case None => + } + if (Properties.envOrNone("SPARK_YARN").isDefined) { + if(isAlphaYarn) { + println("NOTE: SPARK_YARN is deprecated, please use -Pyarn-alpha flag.") + profiles ++= Seq("yarn-alpha") + } + else { + println("NOTE: SPARK_YARN is deprecated, please use -Pyarn flag.") + profiles ++= Seq("yarn") + } } - case Some(v) => v.toBoolean + profiles } - lazy val isYarnEnabled = Properties.envOrNone("SPARK_YARN") match { - case None => DEFAULT_YARN - case Some(v) => v.toBoolean + override val profiles = Properties.envOrNone("SBT_MAVEN_PROFILES") match { + case None => backwardCompatibility + case Some(v) => + if (backwardCompatibility.nonEmpty) + println("Note: We ignore environment variables, when use of profile is detected in " + + "conjunction with environment variable.") + v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq } - lazy val hadoopClient = if (hadoopVersion.startsWith("0.20.") || hadoopVersion == "1.0.0") "hadoop-core" else "hadoop-client" - val maybeAvro = if (hadoopVersion.startsWith("0.23.")) Seq("org.apache.avro" % "avro" % "1.7.4") else Seq() - lazy val isHiveEnabled = Properties.envOrNone("SPARK_HIVE") match { - case None => DEFAULT_HIVE - case Some(v) => v.toBoolean + Properties.envOrNone("SBT_MAVEN_PROPERTIES") match { + case Some(v) => + v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.split("=")).foreach(x => System.setProperty(x(0), x(1))) + case _ => } - // Include Ganglia integration if the user has enabled Ganglia - // This is isolated from the normal build due to LGPL-licensed code in the library - lazy val isGangliaEnabled = Properties.envOrNone("SPARK_GANGLIA_LGPL").isDefined - lazy val gangliaProj = Project("spark-ganglia-lgpl", file("extras/spark-ganglia-lgpl"), settings = gangliaSettings).dependsOn(core) - val maybeGanglia: Seq[ClasspathDependency] = if (isGangliaEnabled) Seq(gangliaProj) else Seq() - val maybeGangliaRef: Seq[ProjectReference] = if (isGangliaEnabled) Seq(gangliaProj) else Seq() + override val userPropertiesMap = System.getProperties.toMap - // Include the Java 8 project if the JVM version is 8+ - lazy val javaVersion = System.getProperty("java.specification.version") - lazy val isJava8Enabled = javaVersion.toDouble >= "1.8".toDouble - val maybeJava8Tests = if (isJava8Enabled) Seq[ProjectReference](java8Tests) else Seq[ProjectReference]() - lazy val java8Tests = Project("java8-tests", file("extras/java8-tests"), settings = java8TestsSettings). - dependsOn(core) dependsOn(streaming % "compile->compile;test->test") - - // Include the YARN project if the user has enabled YARN - lazy val yarnAlpha = Project("yarn-alpha", file("yarn/alpha"), settings = yarnAlphaSettings) dependsOn(core) - lazy val yarn = Project("yarn", file("yarn/stable"), settings = yarnSettings) dependsOn(core) - - lazy val maybeYarn: Seq[ClasspathDependency] = if (isYarnEnabled) Seq(if (isNewHadoop) yarn else yarnAlpha) else Seq() - lazy val maybeYarnRef: Seq[ProjectReference] = if (isYarnEnabled) Seq(if (isNewHadoop) yarn else yarnAlpha) else Seq() - - lazy val externalTwitter = Project("external-twitter", file("external/twitter"), settings = twitterSettings) - .dependsOn(streaming % "compile->compile;test->test") - - lazy val externalKafka = Project("external-kafka", file("external/kafka"), settings = kafkaSettings) - .dependsOn(streaming % "compile->compile;test->test") - - lazy val externalFlume = Project("external-flume", file("external/flume"), settings = flumeSettings) - .dependsOn(streaming % "compile->compile;test->test") - - lazy val externalZeromq = Project("external-zeromq", file("external/zeromq"), settings = zeromqSettings) - .dependsOn(streaming % "compile->compile;test->test") - - lazy val externalMqtt = Project("external-mqtt", file("external/mqtt"), settings = mqttSettings) - .dependsOn(streaming % "compile->compile;test->test") - - lazy val allExternal = Seq[ClasspathDependency](externalTwitter, externalKafka, externalFlume, externalZeromq, externalMqtt) - lazy val allExternalRefs = Seq[ProjectReference](externalTwitter, externalKafka, externalFlume, externalZeromq, externalMqtt) - - lazy val examples = Project("examples", file("examples"), settings = examplesSettings) - .dependsOn(core, mllib, graphx, bagel, streaming, hive) dependsOn(allExternal: _*) - - // Everything except assembly, hive, tools, java8Tests and examples belong to packageProjects - lazy val packageProjects = Seq[ProjectReference](core, repl, bagel, streaming, mllib, graphx, catalyst, sql) ++ maybeYarnRef ++ maybeHiveRef ++ maybeGangliaRef - - lazy val allProjects = packageProjects ++ allExternalRefs ++ - Seq[ProjectReference](examples, tools, assemblyProj) ++ maybeJava8Tests + lazy val MavenCompile = config("m2r") extend(Compile) + lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy") - def sharedSettings = Defaults.defaultSettings ++ MimaBuild.mimaSettings(file(sparkHome)) ++ Seq( - organization := "org.apache.spark", - version := SPARK_VERSION, - scalaVersion := "2.10.4", - scalacOptions := Seq("-Xmax-classfile-name", "120", "-unchecked", "-deprecation", "-feature", - "-target:" + SCALAC_JVM_VERSION), - javacOptions := Seq("-target", JAVAC_JVM_VERSION, "-source", JAVAC_JVM_VERSION), - unmanagedJars in Compile <<= baseDirectory map { base => (base / "lib" ** "*.jar").classpath }, + lazy val sharedSettings = graphSettings ++ ScalaStyleSettings ++ Seq ( + javaHome := Properties.envOrNone("JAVA_HOME").map(file), + incOptions := incOptions.value.withNameHashing(true), retrieveManaged := true, - javaHome := Properties.envOrNone("JAVA_HOME").map(file), - // This is to add convenience of enabling sbt -Dsbt.offline=true for making the build offline. - offline := "true".equalsIgnoreCase(sys.props("sbt.offline")), retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", - transitiveClassifiers in Scope.GlobalScope := Seq("sources"), - testListeners <<= target.map(t => Seq(new eu.henkelmann.sbt.JUnitXmlTestsListener(t.getAbsolutePath))), - incOptions := incOptions.value.withNameHashing(true), - // Fork new JVMs for tests and set Java options for those - fork := true, - javaOptions in Test += "-Dspark.home=" + sparkHome, - javaOptions in Test += "-Dspark.testing=1", - javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", - javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark").map { case (k,v) => s"-D$k=$v" }.toSeq, - javaOptions in Test ++= "-Xmx3g -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g".split(" ").toSeq, - javaOptions += "-Xmx3g", - // Show full stack trace and duration in test cases. - testOptions in Test += Tests.Argument("-oDF"), - // Remove certain packages from Scaladoc - scalacOptions in (Compile, doc) := Seq( - "-groups", - "-skip-packages", Seq( - "akka", - "org.apache.spark.api.python", - "org.apache.spark.network", - "org.apache.spark.deploy", - "org.apache.spark.util.collection" - ).mkString(":"), - "-doc-title", "Spark " + SPARK_VERSION_SHORT + " ScalaDoc" - ), - - // Only allow one test at a time, even across projects, since they run in the same JVM - concurrentRestrictions in Global += Tags.limit(Tags.Test, 1), - - resolvers ++= Seq( - // HTTPS is unavailable for Maven Central - "Maven Repository" at "http://repo.maven.apache.org/maven2", - "Apache Repository" at "https://repository.apache.org/content/repositories/releases", - "JBoss Repository" at "https://repository.jboss.org/nexus/content/repositories/releases/", - "MQTT Repository" at "https://repo.eclipse.org/content/repositories/paho-releases/", - "Cloudera Repository" at "http://repository.cloudera.com/artifactory/cloudera-repos/", - "Pivotal Repository" at "http://repo.spring.io/libs-release/", - // For Sonatype publishing - // "sonatype-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots", - // "sonatype-staging" at "https://oss.sonatype.org/service/local/staging/deploy/maven2/", - // also check the local Maven repository ~/.m2 - Resolver.mavenLocal - ), - publishMavenStyle := true, - // useGpg in Global := true, - - pomExtra := ( - - org.apache - apache - 14 - - http://spark.apache.org/ - - - Apache 2.0 License - http://www.apache.org/licenses/LICENSE-2.0.html - repo - - - - scm:git:git@github.com:apache/spark.git - scm:git:git@github.com:apache/spark.git - - - - matei - Matei Zaharia - matei.zaharia@gmail.com - http://www.cs.berkeley.edu/~matei - Apache Software Foundation - http://spark.apache.org - - - - JIRA - https://issues.apache.org/jira/browse/SPARK - - ), - - /* - publishTo <<= version { (v: String) => - val nexus = "https://oss.sonatype.org/" - if (v.trim.endsWith("SNAPSHOT")) - Some("sonatype-snapshots" at nexus + "content/repositories/snapshots") - else - Some("sonatype-staging" at nexus + "service/local/staging/deploy/maven2") - }, - - */ - - libraryDependencies ++= Seq( - "io.netty" % "netty-all" % "4.0.17.Final", - "org.eclipse.jetty" % "jetty-server" % jettyVersion, - "org.eclipse.jetty" % "jetty-util" % jettyVersion, - "org.eclipse.jetty" % "jetty-plus" % jettyVersion, - "org.eclipse.jetty" % "jetty-security" % jettyVersion, - "org.scalatest" %% "scalatest" % "2.1.5" % "test", - "org.scalacheck" %% "scalacheck" % "1.11.3" % "test", - "com.novocode" % "junit-interface" % "0.10" % "test", - "org.easymock" % "easymockclassextension" % "3.1" % "test", - "org.mockito" % "mockito-all" % "1.9.0" % "test", - "junit" % "junit" % "4.10" % "test", - // Needed by cglib which is needed by easymock. - "asm" % "asm" % "3.3.1" % "test" - ), - - testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), - parallelExecution := true, - /* Workaround for issue #206 (fixed after SBT 0.11.0) */ - watchTransitiveSources <<= Defaults.inDependencies[Task[Seq[File]]](watchSources.task, - const(std.TaskExtra.constant(Nil)), aggregate = true, includeRoot = true) apply { _.join.map(_.flatten) }, - - otherResolvers := Seq(Resolver.file("dotM2", file(Path.userHome + "/.m2/repository"))), + otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))), publishLocalConfiguration in MavenCompile <<= (packagedArtifacts, deliverLocal, ivyLoggingLevel) map { (arts, _, level) => new PublishConfiguration(None, "dotM2", arts, Seq(), level) }, publishMavenStyle in MavenCompile := true, publishLocal in MavenCompile <<= publishTask(publishLocalConfiguration in MavenCompile, deliverLocal), publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn - ) ++ net.virtualvoid.sbt.graph.Plugin.graphSettings ++ ScalaStyleSettings ++ genjavadocSettings - - val akkaVersion = "2.2.3-shaded-protobuf" - val chillVersion = "0.3.6" - val codahaleMetricsVersion = "3.0.0" - val jblasVersion = "1.2.3" - val jets3tVersion = if ("^2\\.[3-9]+".r.findFirstIn(hadoopVersion).isDefined) "0.9.0" else "0.7.1" - val jettyVersion = "8.1.14.v20131031" - val hiveVersion = "0.12.0" - val parquetVersion = "1.4.3" - val slf4jVersion = "1.7.5" - - val excludeJBossNetty = ExclusionRule(organization = "org.jboss.netty") - val excludeIONetty = ExclusionRule(organization = "io.netty") - val excludeEclipseJetty = ExclusionRule(organization = "org.eclipse.jetty") - val excludeAsm = ExclusionRule(organization = "org.ow2.asm") - val excludeOldAsm = ExclusionRule(organization = "asm") - val excludeCommonsLogging = ExclusionRule(organization = "commons-logging") - val excludeSLF4J = ExclusionRule(organization = "org.slf4j") - val excludeScalap = ExclusionRule(organization = "org.scala-lang", artifact = "scalap") - val excludeHadoop = ExclusionRule(organization = "org.apache.hadoop") - val excludeCurator = ExclusionRule(organization = "org.apache.curator") - val excludePowermock = ExclusionRule(organization = "org.powermock") - val excludeFastutil = ExclusionRule(organization = "it.unimi.dsi") - val excludeJruby = ExclusionRule(organization = "org.jruby") - val excludeThrift = ExclusionRule(organization = "org.apache.thrift") - val excludeServletApi = ExclusionRule(organization = "javax.servlet", artifact = "servlet-api") - val excludeJUnit = ExclusionRule(organization = "junit") - - def sparkPreviousArtifact(id: String, organization: String = "org.apache.spark", - version: String = "1.0.0", crossVersion: String = "2.10"): Option[sbt.ModuleID] = { - val fullId = if (crossVersion.isEmpty) id else id + "_" + crossVersion - Some(organization % fullId % version) // the artifact to compare binary compatibility with + ) + + /** Following project only exists to pull previous artifacts of Spark for generating + Mima ignores. For more information see: SPARK 2071 */ + lazy val oldDeps = Project("oldDeps", file("dev"), settings = oldDepsSettings) + + def versionArtifact(id: String): Option[sbt.ModuleID] = { + val fullId = id + "_2.10" + Some("org.apache.spark" % fullId % "1.0.0") } - def coreSettings = sharedSettings ++ Seq( - name := "spark-core", - libraryDependencies ++= Seq( - "com.google.guava" % "guava" % "14.0.1", - "org.apache.commons" % "commons-lang3" % "3.3.2", - "org.apache.commons" % "commons-math3" % "3.3" % "test", - "com.google.code.findbugs" % "jsr305" % "1.3.9", - "log4j" % "log4j" % "1.2.17", - "org.slf4j" % "slf4j-api" % slf4jVersion, - "org.slf4j" % "slf4j-log4j12" % slf4jVersion, - "org.slf4j" % "jul-to-slf4j" % slf4jVersion, - "org.slf4j" % "jcl-over-slf4j" % slf4jVersion, - "commons-daemon" % "commons-daemon" % "1.0.10", // workaround for bug HADOOP-9407 - "com.ning" % "compress-lzf" % "1.0.0", - "org.xerial.snappy" % "snappy-java" % "1.0.5", - "org.spark-project.akka" %% "akka-remote" % akkaVersion, - "org.spark-project.akka" %% "akka-slf4j" % akkaVersion, - "org.spark-project.akka" %% "akka-testkit" % akkaVersion % "test", - "org.json4s" %% "json4s-jackson" % "3.2.6" excludeAll(excludeScalap), - "colt" % "colt" % "1.2.0", - "org.apache.mesos" % "mesos" % "0.18.1" classifier("shaded-protobuf") exclude("com.google.protobuf", "protobuf-java"), - "commons-net" % "commons-net" % "2.2", - "net.java.dev.jets3t" % "jets3t" % jets3tVersion excludeAll(excludeCommonsLogging), - "commons-codec" % "commons-codec" % "1.5", // Prevent jets3t from including the older version of commons-codec - "org.apache.derby" % "derby" % "10.4.2.0" % "test", - "org.apache.hadoop" % hadoopClient % hadoopVersion excludeAll(excludeJBossNetty, excludeAsm, excludeCommonsLogging, excludeSLF4J, excludeOldAsm, excludeServletApi), - "org.apache.curator" % "curator-recipes" % "2.4.0" excludeAll(excludeJBossNetty), - "com.codahale.metrics" % "metrics-core" % codahaleMetricsVersion, - "com.codahale.metrics" % "metrics-jvm" % codahaleMetricsVersion, - "com.codahale.metrics" % "metrics-json" % codahaleMetricsVersion, - "com.codahale.metrics" % "metrics-graphite" % codahaleMetricsVersion, - "com.twitter" %% "chill" % chillVersion excludeAll(excludeAsm), - "com.twitter" % "chill-java" % chillVersion excludeAll(excludeAsm), - "org.tachyonproject" % "tachyon" % "0.4.1-thrift" excludeAll(excludeHadoop, excludeCurator, excludeEclipseJetty, excludePowermock), - "com.clearspring.analytics" % "stream" % "2.7.0" excludeAll(excludeFastutil), // Only HyperLogLogPlus is used, which does not depend on fastutil. - "org.spark-project" % "pyrolite" % "2.0.1", - "net.sf.py4j" % "py4j" % "0.8.1" - ), - libraryDependencies ++= maybeAvro, - assembleDeps, - previousArtifact := sparkPreviousArtifact("spark-core") + def oldDepsSettings() = Defaults.defaultSettings ++ Seq( + name := "old-deps", + scalaVersion := "2.10.4", + retrieveManaged := true, + retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", + libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq", + "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter", + "spark-streaming", "spark-mllib", "spark-bagel", "spark-graphx", + "spark-core").map(versionArtifact(_).get intransitive()) ) - // Create a colon-separate package list adding "org.apache.spark" in front of all of them, - // for easier specification of JavaDoc package groups - def packageList(names: String*): String = { - names.map(s => "org.apache.spark." + s).mkString(":") + def enable(settings: Seq[Setting[_]])(projectRef: ProjectRef) = { + val existingSettings = projectsMap.getOrElse(projectRef.project, Seq[Setting[_]]()) + projectsMap += (projectRef.project -> (existingSettings ++ settings)) } - def rootSettings = sharedSettings ++ scalaJavaUnidocSettings ++ Seq( - publish := {}, + // Note ordering of these settings matter. + /* Enable shared settings on all projects */ + (allProjects ++ optionallyEnabledProjects ++ assemblyProjects).foreach(enable(sharedSettings)) - unidocProjectFilter in (ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(repl, examples, tools, catalyst, yarn, yarnAlpha), - unidocProjectFilter in (JavaUnidoc, unidoc) := - inAnyProject -- inProjects(repl, examples, bagel, graphx, catalyst, tools, yarn, yarnAlpha), + /* Enable tests settings for all projects except examples, assembly and tools */ + (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) - // Skip class names containing $ and some internal packages in Javadocs - unidocAllSources in (JavaUnidoc, unidoc) := { - (unidocAllSources in (JavaUnidoc, unidoc)).value - .map(_.filterNot(_.getName.contains("$"))) - .map(_.filterNot(_.getCanonicalPath.contains("akka"))) - .map(_.filterNot(_.getCanonicalPath.contains("deploy"))) - .map(_.filterNot(_.getCanonicalPath.contains("network"))) - .map(_.filterNot(_.getCanonicalPath.contains("executor"))) - .map(_.filterNot(_.getCanonicalPath.contains("python"))) - .map(_.filterNot(_.getCanonicalPath.contains("collection"))) - }, + /* Enable Mima for all projects except spark, hive, catalyst, sql and repl */ + // TODO: Add Sql to mima checks + allProjects.filterNot(y => Seq(spark, sql, hive, catalyst, repl).exists(x => x == y)). + foreach (x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)) - // Javadoc options: create a window title, and group key packages on index page - javacOptions in doc := Seq( - "-windowtitle", "Spark " + SPARK_VERSION_SHORT + " JavaDoc", - "-public", - "-group", "Core Java API", packageList("api.java", "api.java.function"), - "-group", "Spark Streaming", packageList( - "streaming.api.java", "streaming.flume", "streaming.kafka", - "streaming.mqtt", "streaming.twitter", "streaming.zeromq" - ), - "-group", "MLlib", packageList( - "mllib.classification", "mllib.clustering", "mllib.evaluation.binary", "mllib.linalg", - "mllib.linalg.distributed", "mllib.optimization", "mllib.rdd", "mllib.recommendation", - "mllib.regression", "mllib.stat", "mllib.tree", "mllib.tree.configuration", - "mllib.tree.impurity", "mllib.tree.model", "mllib.util" - ), - "-group", "Spark SQL", packageList("sql.api.java", "sql.hive.api.java"), - "-noqualifier", "java.lang" - ) - ) + /* Enable Assembly for all assembly projects */ + assemblyProjects.foreach(enable(Assembly.settings)) - def replSettings = sharedSettings ++ Seq( - name := "spark-repl", - libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "scala-compiler" % v), - libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "jline" % v), - libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "scala-reflect" % v) - ) + /* Enable unidoc only for the root spark project */ + enable(Unidoc.settings)(spark) - def examplesSettings = sharedSettings ++ Seq( - name := "spark-examples", - jarName in assembly <<= version map { - v => "spark-examples-" + v + "-hadoop" + hadoopVersion + ".jar" }, - libraryDependencies ++= Seq( - "com.twitter" %% "algebird-core" % "0.1.11", - "org.apache.hbase" % "hbase" % HBASE_VERSION excludeAll(excludeIONetty, excludeJBossNetty, excludeAsm, excludeOldAsm, excludeCommonsLogging, excludeJruby), - "org.apache.cassandra" % "cassandra-all" % "1.2.6" - exclude("com.google.guava", "guava") - exclude("com.googlecode.concurrentlinkedhashmap", "concurrentlinkedhashmap-lru") - exclude("com.ning","compress-lzf") - exclude("io.netty", "netty") - exclude("jline","jline") - exclude("org.apache.cassandra.deps", "avro") - excludeAll(excludeSLF4J, excludeIONetty), - "com.github.scopt" %% "scopt" % "3.2.0" - ) - ) ++ assemblySettings ++ extraAssemblySettings - - def toolsSettings = sharedSettings ++ Seq( - name := "spark-tools", - libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "scala-compiler" % v), - libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "scala-reflect" % v ) - ) ++ assemblySettings ++ extraAssemblySettings - - def graphxSettings = sharedSettings ++ Seq( - name := "spark-graphx", - previousArtifact := sparkPreviousArtifact("spark-graphx"), - libraryDependencies ++= Seq( - "org.jblas" % "jblas" % jblasVersion - ) - ) + /* Spark SQL Core console settings */ + enable(SQL.settings)(sql) - def bagelSettings = sharedSettings ++ Seq( - name := "spark-bagel", - previousArtifact := sparkPreviousArtifact("spark-bagel") - ) + /* Hive console settings */ + enable(Hive.settings)(hive) - def mllibSettings = sharedSettings ++ Seq( - name := "spark-mllib", - previousArtifact := sparkPreviousArtifact("spark-mllib"), - libraryDependencies ++= Seq( - "org.jblas" % "jblas" % jblasVersion, - "org.scalanlp" %% "breeze" % "0.7" excludeAll(excludeJUnit) - ) - ) + // TODO: move this to its upstream project. + override def projectDefinitions(baseDirectory: File): Seq[Project] = { + super.projectDefinitions(baseDirectory).map { x => + if (projectsMap.exists(_._1 == x.id)) x.settings(projectsMap(x.id): _*) + else x.settings(Seq[Setting[_]](): _*) + } ++ Seq[Project](oldDeps) + } - def catalystSettings = sharedSettings ++ Seq( - name := "catalyst", - // The mechanics of rewriting expression ids to compare trees in some test cases makes - // assumptions about the the expression ids being contiguous. Running tests in parallel breaks - // this non-deterministically. TODO: FIX THIS. - parallelExecution in Test := false, - libraryDependencies ++= Seq( - "com.typesafe" %% "scalalogging-slf4j" % "1.0.1" - ) - ) +} + +object SQL { + + lazy val settings = Seq( - def sqlCoreSettings = sharedSettings ++ Seq( - name := "spark-sql", - libraryDependencies ++= Seq( - "com.twitter" % "parquet-column" % parquetVersion, - "com.twitter" % "parquet-hadoop" % parquetVersion, - "com.fasterxml.jackson.core" % "jackson-databind" % "2.3.0" // json4s-jackson 3.2.6 requires jackson-databind 2.3.0. - ), initialCommands in console := """ |import org.apache.spark.sql.catalyst.analysis._ @@ -522,17 +201,14 @@ object SparkBuild extends Build { |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin ) - // Since we don't include hive in the main assembly this project also acts as an alternative - // assembly jar. - def hiveSettings = sharedSettings ++ Seq( - name := "spark-hive", +} + +object Hive { + + lazy val settings = Seq( + javaOptions += "-XX:MaxPermSize=1g", - libraryDependencies ++= Seq( - "org.spark-project.hive" % "hive-metastore" % hiveVersion, - "org.spark-project.hive" % "hive-exec" % hiveVersion excludeAll(excludeCommonsLogging), - "org.spark-project.hive" % "hive-serde" % hiveVersion - ), - // Multiple queries rely on the TestHive singleton. See comments there for more details. + // Multiple queries rely on the TestHive singleton. See comments there for more details. parallelExecution in Test := false, // Supporting all SerDes requires us to depend on deprecated APIs, so we turn off the warnings // only for this subproject. @@ -555,67 +231,16 @@ object SparkBuild extends Build { |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin ) - def streamingSettings = sharedSettings ++ Seq( - name := "spark-streaming", - previousArtifact := sparkPreviousArtifact("spark-streaming") - ) - - def yarnCommonSettings = sharedSettings ++ Seq( - unmanagedSourceDirectories in Compile <++= baseDirectory { base => - Seq( - base / "../common/src/main/scala" - ) - }, - - unmanagedSourceDirectories in Test <++= baseDirectory { base => - Seq( - base / "../common/src/test/scala" - ) - } - - ) ++ extraYarnSettings - - def yarnAlphaSettings = yarnCommonSettings ++ Seq( - name := "spark-yarn-alpha" - ) - - def yarnSettings = yarnCommonSettings ++ Seq( - name := "spark-yarn" - ) - - def gangliaSettings = sharedSettings ++ Seq( - name := "spark-ganglia-lgpl", - libraryDependencies += "com.codahale.metrics" % "metrics-ganglia" % "3.0.0" - ) - - def java8TestsSettings = sharedSettings ++ Seq( - name := "java8-tests", - javacOptions := Seq("-target", "1.8", "-source", "1.8"), - testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a") - ) - - // Conditionally include the YARN dependencies because some tools look at all sub-projects and will complain - // if we refer to nonexistent dependencies (e.g. hadoop-yarn-api from a Hadoop version without YARN). - def extraYarnSettings = if(isYarnEnabled) yarnEnabledSettings else Seq() - - def yarnEnabledSettings = Seq( - libraryDependencies ++= Seq( - // Exclude rule required for all ? - "org.apache.hadoop" % hadoopClient % hadoopVersion excludeAll(excludeJBossNetty, excludeAsm, excludeOldAsm), - "org.apache.hadoop" % "hadoop-yarn-api" % hadoopVersion excludeAll(excludeJBossNetty, excludeAsm, excludeOldAsm, excludeCommonsLogging), - "org.apache.hadoop" % "hadoop-yarn-common" % hadoopVersion excludeAll(excludeJBossNetty, excludeAsm, excludeOldAsm, excludeCommonsLogging), - "org.apache.hadoop" % "hadoop-yarn-client" % hadoopVersion excludeAll(excludeJBossNetty, excludeAsm, excludeOldAsm, excludeCommonsLogging), - "org.apache.hadoop" % "hadoop-yarn-server-web-proxy" % hadoopVersion excludeAll(excludeJBossNetty, excludeAsm, excludeOldAsm, excludeCommonsLogging, excludeServletApi) - ) - ) +} - def assemblyProjSettings = sharedSettings ++ Seq( - name := "spark-assembly", - jarName in assembly <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" } - ) ++ assemblySettings ++ extraAssemblySettings +object Assembly { + import sbtassembly.Plugin._ + import AssemblyKeys._ - def extraAssemblySettings() = Seq( + lazy val settings = assemblySettings ++ Seq( test in assembly := {}, + jarName in assembly <<= (version, moduleName) map { (v, mName) => mName + "-"+v + "-hadoop" + + Option(System.getProperty("hadoop.version")).getOrElse("1.0.4") + ".jar" }, mergeStrategy in assembly := { case PathList("org", "datanucleus", xs @ _*) => MergeStrategy.discard case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard @@ -627,57 +252,95 @@ object SparkBuild extends Build { } ) - def oldDepsSettings() = Defaults.defaultSettings ++ Seq( - name := "old-deps", - scalaVersion := "2.10.4", - retrieveManaged := true, - retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", - libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq", - "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter", - "spark-streaming", "spark-mllib", "spark-bagel", "spark-graphx", - "spark-core").map(sparkPreviousArtifact(_).get intransitive()) - ) +} - def twitterSettings() = sharedSettings ++ Seq( - name := "spark-streaming-twitter", - previousArtifact := sparkPreviousArtifact("spark-streaming-twitter"), - libraryDependencies ++= Seq( - "org.twitter4j" % "twitter4j-stream" % "3.0.3" - ) - ) +object Unidoc { - def kafkaSettings() = sharedSettings ++ Seq( - name := "spark-streaming-kafka", - previousArtifact := sparkPreviousArtifact("spark-streaming-kafka"), - libraryDependencies ++= Seq( - "com.github.sgroschupf" % "zkclient" % "0.1", - "org.apache.kafka" %% "kafka" % "0.8.0" - exclude("com.sun.jdmk", "jmxtools") - exclude("com.sun.jmx", "jmxri") - exclude("net.sf.jopt-simple", "jopt-simple") - excludeAll(excludeSLF4J) - ) - ) + import BuildCommons._ + import sbtunidoc.Plugin._ + import UnidocKeys._ + + // for easier specification of JavaDoc package groups + private def packageList(names: String*): String = { + names.map(s => "org.apache.spark." + s).mkString(":") + } - def flumeSettings() = sharedSettings ++ Seq( - name := "spark-streaming-flume", - previousArtifact := sparkPreviousArtifact("spark-streaming-flume"), - libraryDependencies ++= Seq( - "org.apache.flume" % "flume-ng-sdk" % "1.4.0" % "compile" excludeAll(excludeIONetty, excludeThrift) + lazy val settings = scalaJavaUnidocSettings ++ Seq ( + publish := {}, + + unidocProjectFilter in(ScalaUnidoc, unidoc) := + inAnyProject -- inProjects(repl, examples, tools, catalyst, yarn, yarnAlpha), + unidocProjectFilter in(JavaUnidoc, unidoc) := + inAnyProject -- inProjects(repl, bagel, graphx, examples, tools, catalyst, yarn, yarnAlpha), + + // Skip class names containing $ and some internal packages in Javadocs + unidocAllSources in (JavaUnidoc, unidoc) := { + (unidocAllSources in (JavaUnidoc, unidoc)).value + .map(_.filterNot(_.getName.contains("$"))) + .map(_.filterNot(_.getCanonicalPath.contains("akka"))) + .map(_.filterNot(_.getCanonicalPath.contains("deploy"))) + .map(_.filterNot(_.getCanonicalPath.contains("network"))) + .map(_.filterNot(_.getCanonicalPath.contains("executor"))) + .map(_.filterNot(_.getCanonicalPath.contains("python"))) + .map(_.filterNot(_.getCanonicalPath.contains("collection"))) + }, + + // Javadoc options: create a window title, and group key packages on index page + javacOptions in doc := Seq( + "-windowtitle", "Spark " + version.value.replaceAll("-SNAPSHOT", "") + " JavaDoc", + "-public", + "-group", "Core Java API", packageList("api.java", "api.java.function"), + "-group", "Spark Streaming", packageList( + "streaming.api.java", "streaming.flume", "streaming.kafka", + "streaming.mqtt", "streaming.twitter", "streaming.zeromq" + ), + "-group", "MLlib", packageList( + "mllib.classification", "mllib.clustering", "mllib.evaluation.binary", "mllib.linalg", + "mllib.linalg.distributed", "mllib.optimization", "mllib.rdd", "mllib.recommendation", + "mllib.regression", "mllib.stat", "mllib.tree", "mllib.tree.configuration", + "mllib.tree.impurity", "mllib.tree.model", "mllib.util" + ), + "-group", "Spark SQL", packageList("sql.api.java", "sql.hive.api.java"), + "-noqualifier", "java.lang" ) ) +} - def zeromqSettings() = sharedSettings ++ Seq( - name := "spark-streaming-zeromq", - previousArtifact := sparkPreviousArtifact("spark-streaming-zeromq"), - libraryDependencies ++= Seq( - "org.spark-project.akka" %% "akka-zeromq" % akkaVersion +object TestSettings { + import BuildCommons._ + + lazy val settings = Seq ( + // Fork new JVMs for tests and set Java options for those + fork := true, + javaOptions in Test += "-Dspark.home=" + sparkHome, + javaOptions in Test += "-Dspark.testing=1", + javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", + javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") + .map { case (k,v) => s"-D$k=$v" }.toSeq, + javaOptions in Test ++= "-Xmx3g -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" + .split(" ").toSeq, + javaOptions += "-Xmx3g", + + // Show full stack trace and duration in test cases. + testOptions in Test += Tests.Argument("-oDF"), + testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), + // Enable Junit testing. + libraryDependencies += "com.novocode" % "junit-interface" % "0.9" % "test", + // Only allow one test at a time, even across projects, since they run in the same JVM + parallelExecution in Test := false, + concurrentRestrictions in Global += Tags.limit(Tags.Test, 1), + // Remove certain packages from Scaladoc + scalacOptions in (Compile, doc) := Seq( + "-groups", + "-skip-packages", Seq( + "akka", + "org.apache.spark.api.python", + "org.apache.spark.network", + "org.apache.spark.deploy", + "org.apache.spark.util.collection" + ).mkString(":"), + "-doc-title", "Spark " + version.value.replaceAll("-SNAPSHOT", "") + " ScalaDoc" ) ) - def mqttSettings() = streamingSettings ++ Seq( - name := "spark-streaming-mqtt", - previousArtifact := sparkPreviousArtifact("spark-streaming-mqtt"), - libraryDependencies ++= Seq("org.eclipse.paho" % "mqtt-client" % "0.4.0") - ) } diff --git a/project/build.properties b/project/build.properties index bcde13f4362a7..c12ef652adfcb 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -sbt.version=0.13.2 +sbt.version=0.13.5 diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala index e9fba641eb8a1..3ef2d5451da0d 100644 --- a/project/project/SparkPluginBuild.scala +++ b/project/project/SparkPluginBuild.scala @@ -24,8 +24,10 @@ import sbt.Keys._ * becomes available for scalastyle sbt plugin. */ object SparkPluginDef extends Build { - lazy val root = Project("plugins", file(".")) dependsOn(sparkStyle) + lazy val root = Project("plugins", file(".")) dependsOn(sparkStyle, sbtPomReader) lazy val sparkStyle = Project("spark-style", file("spark-style"), settings = styleSettings) + lazy val sbtPomReader = uri("https://github.com/ScrapCodes/sbt-pom-reader.git") + // There is actually no need to publish this artifact. def styleSettings = Defaults.defaultSettings ++ Seq ( name := "spark-style", diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index eb5dbb8de2b39..4fda2a9b950b8 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -243,10 +243,10 @@ def save_function(self, obj, name=None, pack=struct.pack): # if func is lambda, def'ed at prompt, is in main, or is nested, then # we'll pickle the actual function object rather than simply saving a # reference (as is done in default pickler), via save_function_tuple. - if islambda(obj) or obj.func_code.co_filename == '' or themodule == None: + if islambda(obj) or obj.func_code.co_filename == '' or themodule is None: #Force server to import modules that have been imported in main modList = None - if themodule == None and not self.savedForceImports: + if themodule is None and not self.savedForceImports: mainmod = sys.modules['__main__'] if useForcedImports and hasattr(mainmod,'___pyc_forcedImports__'): modList = list(mainmod.___pyc_forcedImports__) diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index 8eff4a242a529..60fc6ba7c52c2 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -30,7 +30,7 @@ u'local' >>> sc.appName u'My app' ->>> sc.sparkHome == None +>>> sc.sparkHome is None True >>> conf = SparkConf(loadDefaults=False) @@ -116,7 +116,7 @@ def setSparkHome(self, value): def setExecutorEnv(self, key=None, value=None, pairs=None): """Set an environment variable to be passed to executors.""" - if (key != None and pairs != None) or (key == None and pairs == None): + if (key is not None and pairs is not None) or (key is None and pairs is None): raise Exception("Either pass one key-value pair or a list of pairs") elif key != None: self._jconf.setExecutorEnv(key, value) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 0dbead4415b02..2a17127a7e0f9 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -56,7 +56,7 @@ def preexec_func(): (stdout, _) = proc.communicate() exit_code = proc.poll() error_msg = "Launching GatewayServer failed" - error_msg += " with exit code %d!" % exit_code if exit_code else "! " + error_msg += " with exit code %d! " % exit_code if exit_code else "! " error_msg += "(Warning: unexpected output detected.)\n\n" error_msg += gateway_port + stdout raise Exception(error_msg) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index f64f48e3a4c9c..0c35c666805dd 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -69,16 +69,19 @@ def _extract_concise_traceback(): file, line, fun, what = tb[0] return callsite(function=fun, file=file, linenum=line) sfile, sline, sfun, swhat = tb[first_spark_frame] - ufile, uline, ufun, uwhat = tb[first_spark_frame-1] + ufile, uline, ufun, uwhat = tb[first_spark_frame - 1] return callsite(function=sfun, file=ufile, linenum=uline) _spark_stack_depth = 0 + class _JavaStackTrace(object): + def __init__(self, sc): tb = _extract_concise_traceback() if tb is not None: - self._traceback = "%s at %s:%s" % (tb.function, tb.file, tb.linenum) + self._traceback = "%s at %s:%s" % ( + tb.function, tb.file, tb.linenum) else: self._traceback = "Error! Could not extract traceback info" self._context = sc @@ -95,7 +98,9 @@ def __exit__(self, type, value, tb): if _spark_stack_depth == 0: self._context._jsc.setCallSite(None) + class MaxHeapQ(object): + """ An implementation of MaxHeap. >>> import pyspark.rdd @@ -117,14 +122,14 @@ class MaxHeapQ(object): """ def __init__(self, maxsize): - # we start from q[1], this makes calculating children as trivial as 2 * k + # We start from q[1], so its children are always 2 * k self.q = [0] self.maxsize = maxsize def _swim(self, k): - while (k > 1) and (self.q[k/2] < self.q[k]): - self._swap(k, k/2) - k = k/2 + while (k > 1) and (self.q[k / 2] < self.q[k]): + self._swap(k, k / 2) + k = k / 2 def _swap(self, i, j): t = self.q[i] @@ -162,7 +167,9 @@ def _replaceRoot(self, value): self.q[1] = value self._sink(1) + class RDD(object): + """ A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable, partitioned collection of elements that can be @@ -257,7 +264,8 @@ def map(self, f, preservesPartitioning=False): >>> sorted(rdd.map(lambda x: (x, 1)).collect()) [('a', 1), ('b', 1), ('c', 1)] """ - def func(split, iterator): return imap(f, iterator) + def func(split, iterator): + return imap(f, iterator) return PipelinedRDD(self, func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -271,7 +279,8 @@ def flatMap(self, f, preservesPartitioning=False): >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect()) [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ - def func(s, iterator): return chain.from_iterable(imap(f, iterator)) + def func(s, iterator): + return chain.from_iterable(imap(f, iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): @@ -283,7 +292,8 @@ def mapPartitions(self, f, preservesPartitioning=False): >>> rdd.mapPartitions(f).collect() [3, 7] """ - def func(s, iterator): return f(iterator) + def func(s, iterator): + return f(iterator) return self.mapPartitionsWithIndex(func) def mapPartitionsWithIndex(self, f, preservesPartitioning=False): @@ -311,17 +321,17 @@ def mapPartitionsWithSplit(self, f, preservesPartitioning=False): 6 """ warnings.warn("mapPartitionsWithSplit is deprecated; " - "use mapPartitionsWithIndex instead", DeprecationWarning, stacklevel=2) + "use mapPartitionsWithIndex instead", DeprecationWarning, stacklevel=2) return self.mapPartitionsWithIndex(f, preservesPartitioning) def getNumPartitions(self): - """ - Returns the number of partitions in RDD - >>> rdd = sc.parallelize([1, 2, 3, 4], 2) - >>> rdd.getNumPartitions() - 2 - """ - return self._jrdd.partitions().size() + """ + Returns the number of partitions in RDD + >>> rdd = sc.parallelize([1, 2, 3, 4], 2) + >>> rdd.getNumPartitions() + 2 + """ + return self._jrdd.partitions().size() def filter(self, f): """ @@ -331,7 +341,8 @@ def filter(self, f): >>> rdd.filter(lambda x: x % 2 == 0).collect() [2, 4] """ - def func(iterator): return ifilter(f, iterator) + def func(iterator): + return ifilter(f, iterator) return self.mapPartitions(func) def distinct(self): @@ -391,9 +402,11 @@ def takeSample(self, withReplacement, num, seed=None): maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint)) if num > maxSampleSize: - raise ValueError("Sample size cannot be greater than %d." % maxSampleSize) + raise ValueError( + "Sample size cannot be greater than %d." % maxSampleSize) - fraction = RDD._computeFractionForSampleSize(num, initialCount, withReplacement) + fraction = RDD._computeFractionForSampleSize( + num, initialCount, withReplacement) samples = self.sample(withReplacement, fraction, seed).collect() # If the first sample didn't turn out large enough, keep trying to take samples; @@ -499,17 +512,17 @@ def __add__(self, other): raise TypeError return self.union(other) - def sortByKey(self, ascending=True, numPartitions=None, keyfunc = lambda x: x): + def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): """ Sorts this RDD, which is assumed to consist of (key, value) pairs. - + # noqa >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)] >>> sc.parallelize(tmp).sortByKey(True, 2).collect() [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)] >>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)] >>> tmp2.extend([('whose', 6), ('fleece', 7), ('was', 8), ('white', 9)]) >>> sc.parallelize(tmp2).sortByKey(True, 3, keyfunc=lambda k: k.lower()).collect() - [('a', 3), ('fleece', 7), ('had', 2), ('lamb', 5), ('little', 4), ('Mary', 1), ('was', 8), ('white', 9), ('whose', 6)] + [('a', 3), ('fleece', 7), ('had', 2), ('lamb', 5),...('white', 9), ('whose', 6)] """ if numPartitions is None: numPartitions = self._defaultReducePartitions() @@ -521,10 +534,12 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc = lambda x: x): # number of (key, value) pairs falling into them if numPartitions > 1: rddSize = self.count() - maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner + # constant from Spark's RangePartitioner + maxSampleSize = numPartitions * 20.0 fraction = min(maxSampleSize / max(rddSize, 1), 1.0) - samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect() + samples = self.sample(False, fraction, 1).map( + lambda (k, v): k).collect() samples = sorted(samples, reverse=(not ascending), key=keyfunc) # we have numPartitions many parts but one of the them has @@ -540,13 +555,13 @@ def rangePartitionFunc(k): if ascending: return p else: - return numPartitions-1-p + return numPartitions - 1 - p def mapFunc(iterator): yield sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k)) return (self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc) - .mapPartitions(mapFunc,preservesPartitioning=True) + .mapPartitions(mapFunc, preservesPartitioning=True) .flatMap(lambda x: x, preservesPartitioning=True)) def sortBy(self, keyfunc, ascending=True, numPartitions=None): @@ -570,7 +585,8 @@ def glom(self): >>> sorted(rdd.glom().collect()) [[1, 2], [3, 4]] """ - def func(iterator): yield list(iterator) + def func(iterator): + yield list(iterator) return self.mapPartitions(func) def cartesian(self, other): @@ -607,7 +623,9 @@ def pipe(self, command, env={}): ['1', '2', '', '3'] """ def func(iterator): - pipe = Popen(shlex.split(command), env=env, stdin=PIPE, stdout=PIPE) + pipe = Popen( + shlex.split(command), env=env, stdin=PIPE, stdout=PIPE) + def pipe_objs(out): for obj in iterator: out.write(str(obj).rstrip('\n') + '\n') @@ -646,7 +664,7 @@ def collect(self): Return a list that contains all of the elements in this RDD. """ with _JavaStackTrace(self.context) as st: - bytesInJava = self._jrdd.collect().iterator() + bytesInJava = self._jrdd.collect().iterator() return list(self._collect_iterator_through_file(bytesInJava)) def _collect_iterator_through_file(self, iterator): @@ -736,7 +754,6 @@ def func(iterator): return self.mapPartitions(func).fold(zeroValue, combOp) - def max(self): """ Find the maximum item in this RDD. @@ -844,6 +861,7 @@ def countPartition(iterator): for obj in iterator: counts[obj] += 1 yield counts + def mergeMaps(m1, m2): for (k, v) in m2.iteritems(): m1[k] += v @@ -888,22 +906,22 @@ def takeOrdered(self, num, key=None): def topNKeyedElems(iterator, key_=None): q = MaxHeapQ(num) for k in iterator: - if key_ != None: + if key_ is not None: k = (key_(k), k) q.insert(k) yield q.getElements() def unKey(x, key_=None): - if key_ != None: + if key_ is not None: x = [i[1] for i in x] return x def merge(a, b): return next(topNKeyedElems(a + b)) - result = self.mapPartitions(lambda i: topNKeyedElems(i, key)).reduce(merge) + result = self.mapPartitions( + lambda i: topNKeyedElems(i, key)).reduce(merge) return sorted(unKey(result, key), key=key) - def take(self, num): """ Take the first num elements of the RDD. @@ -947,7 +965,8 @@ def takeUpToNumLeft(iterator): yield next(iterator) taken += 1 - p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts)) + p = range( + partsScanned, min(partsScanned + numPartsToTry, totalParts)) res = self.context.runJob(self, takeUpToNumLeft, p, True) items += res @@ -977,7 +996,7 @@ def saveAsPickleFile(self, path, batchSize=10): [1, 2, 'rdd', 'spark'] """ self._reserialize(BatchedSerializer(PickleSerializer(), - batchSize))._jrdd.saveAsObjectFile(path) + batchSize))._jrdd.saveAsObjectFile(path) def saveAsTextFile(self, path): """ @@ -1075,6 +1094,7 @@ def reducePartition(iterator): for (k, v) in iterator: m[k] = v if k not in m else func(m[k], v) yield m + def mergeMaps(m1, m2): for (k, v) in m2.iteritems(): m1[k] = v if k not in m1 else func(m1[k], v) @@ -1162,6 +1182,7 @@ def partitionBy(self, numPartitions, partitionFunc=None): # form the hash buckets in Python, transferring O(numPartitions) objects # to Java. Each object is a (splitNumber, [objects]) pair. outputSerializer = self.ctx._unbatched_serializer + def add_shuffle_key(split, iterator): buckets = defaultdict(list) @@ -1174,7 +1195,8 @@ def add_shuffle_key(split, iterator): keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True with _JavaStackTrace(self.context) as st: - pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() + pairRDD = self.ctx._jvm.PairwiseRDD( + keyed._jrdd.rdd()).asJavaPairRDD() partitioner = self.ctx._jvm.PythonPartitioner(numPartitions, id(partitionFunc)) jrdd = pairRDD.partitionBy(partitioner).values() @@ -1213,6 +1235,7 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, """ if numPartitions is None: numPartitions = self._defaultReducePartitions() + def combineLocally(iterator): combiners = {} for x in iterator: @@ -1224,10 +1247,11 @@ def combineLocally(iterator): return combiners.iteritems() locally_combined = self.mapPartitions(combineLocally) shuffled = locally_combined.partitionBy(numPartitions) + def _mergeCombiners(iterator): combiners = {} for (k, v) in iterator: - if not k in combiners: + if k not in combiners: combiners[k] = v else: combiners[k] = mergeCombiners(combiners[k], v) @@ -1236,17 +1260,19 @@ def _mergeCombiners(iterator): def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None): """ - Aggregate the values of each key, using given combine functions and a neutral "zero value". - This function can return a different result type, U, than the type of the values in this RDD, - V. Thus, we need one operation for merging a V into a U and one operation for merging two U's, - The former operation is used for merging values within a partition, and the latter is used - for merging values between partitions. To avoid memory allocation, both of these functions are + Aggregate the values of each key, using given combine functions and a neutral + "zero value". This function can return a different result type, U, than the type + of the values in this RDD, V. Thus, we need one operation for merging a V into + a U and one operation for merging two U's, The former operation is used for merging + values within a partition, and the latter is used for merging values between + partitions. To avoid memory allocation, both of these functions are allowed to modify and return their first argument instead of creating a new U. """ def createZero(): - return copy.deepcopy(zeroValue) + return copy.deepcopy(zeroValue) - return self.combineByKey(lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions) + return self.combineByKey( + lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions) def foldByKey(self, zeroValue, func, numPartitions=None): """ @@ -1261,11 +1287,10 @@ def foldByKey(self, zeroValue, func, numPartitions=None): [('a', 2), ('b', 1)] """ def createZero(): - return copy.deepcopy(zeroValue) + return copy.deepcopy(zeroValue) return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions) - # TODO: support variant with custom partitioner def groupByKey(self, numPartitions=None): """ @@ -1292,7 +1317,7 @@ def mergeCombiners(a, b): return a + b return self.combineByKey(createCombiner, mergeValue, mergeCombiners, - numPartitions).mapValues(lambda x: ResultIterable(x)) + numPartitions).mapValues(lambda x: ResultIterable(x)) # TODO: add tests def flatMapValues(self, f): @@ -1362,7 +1387,8 @@ def subtractByKey(self, other, numPartitions=None): >>> sorted(x.subtractByKey(y).collect()) [('b', 4), ('b', 5)] """ - filter_func = lambda (key, vals): len(vals[0]) > 0 and len(vals[1]) == 0 + def filter_func((key, vals)): + return len(vals[0]) > 0 and len(vals[1]) == 0 map_func = lambda (key, vals): [(key, val) for val in vals[0]] return self.cogroup(other, numPartitions).filter(filter_func).flatMap(map_func) @@ -1375,8 +1401,9 @@ def subtract(self, other, numPartitions=None): >>> sorted(x.subtract(y).collect()) [('a', 1), ('b', 4), ('b', 5)] """ - rdd = other.map(lambda x: (x, True)) # note: here 'True' is just a placeholder - return self.map(lambda x: (x, True)).subtractByKey(rdd).map(lambda tpl: tpl[0]) # note: here 'True' is just a placeholder + # note: here 'True' is just a placeholder + rdd = other.map(lambda x: (x, True)) + return self.map(lambda x: (x, True)).subtractByKey(rdd).map(lambda tpl: tpl[0]) def keyBy(self, f): """ @@ -1434,7 +1461,7 @@ def zip(self, other): """ pairRDD = self._jrdd.zip(other._jrdd) deserializer = PairDeserializer(self._jrdd_deserializer, - other._jrdd_deserializer) + other._jrdd_deserializer) return RDD(pairRDD, self.ctx, deserializer) def name(self): @@ -1503,7 +1530,9 @@ def _defaultReducePartitions(self): # keys in the pairs. This could be an expensive operation, since those # hashes aren't retained. + class PipelinedRDD(RDD): + """ Pipelined maps: >>> rdd = sc.parallelize([1, 2, 3, 4]) @@ -1519,6 +1548,7 @@ class PipelinedRDD(RDD): >>> rdd.flatMap(lambda x: [x, x]).reduce(add) 20 """ + def __init__(self, prev, func, preservesPartitioning=False): if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable(): # This transformation is the first in its stage: @@ -1528,6 +1558,7 @@ def __init__(self, prev, func, preservesPartitioning=False): self._prev_jrdd_deserializer = prev._jrdd_deserializer else: prev_func = prev.func + def pipeline_func(split, iterator): return func(split, prev_func(split, iterator)) self.func = pipeline_func @@ -1560,11 +1591,13 @@ def _jrdd(self): env = MapConverter().convert(self.ctx.environment, self.ctx._gateway._gateway_client) includes = ListConverter().convert(self.ctx._python_includes, - self.ctx._gateway._gateway_client) + self.ctx._gateway._gateway_client) python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), - bytearray(pickled_command), env, includes, self.preservesPartitioning, - self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator, - class_tag) + bytearray(pickled_command), + env, includes, self.preservesPartitioning, + self.ctx.pythonExec, + broadcast_vars, self.ctx._javaAccumulator, + class_tag) self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val @@ -1579,7 +1612,8 @@ def _test(): # The small batch size here ensures that we see multiple batches, # even in these small test examples: globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) - (failure_count, test_count) = doctest.testmod(globs=globs,optionflags=doctest.ELLIPSIS) + (failure_count, test_count) = doctest.testmod( + globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: exit(-1) diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py index 845a267e311c5..122bc38b03b0c 100644 --- a/python/pyspark/rddsampler.py +++ b/python/pyspark/rddsampler.py @@ -82,7 +82,7 @@ def getPoissonSample(self, split, mean): return (num_arrivals - 1) def shuffle(self, vals): - if self._random == None: + if self._random is None: self.initRandomGenerator(0) # this should only ever called on the master so # the split does not matter diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index ebd714db7a918..2ce5409cd67c2 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -35,7 +35,7 @@ from pyspark.storagelevel import StorageLevel # this is the equivalent of ADD_JARS -add_files = os.environ.get("ADD_FILES").split(',') if os.environ.get("ADD_FILES") != None else None +add_files = os.environ.get("ADD_FILES").split(',') if os.environ.get("ADD_FILES") is not None else None if os.environ.get("SPARK_EXECUTOR_URI"): SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) @@ -55,7 +55,7 @@ platform.python_build()[1])) print("SparkContext available as sc.") -if add_files != None: +if add_files is not None: print("Adding files: [%s]" % ", ".join(add_files)) # The ./bin/pyspark script stores the old PYTHONSTARTUP value in OLD_PYTHONSTARTUP, diff --git a/repl/pom.xml b/repl/pom.xml index 4a66408ef3d2d..4ebb1b82f0e8c 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -32,6 +32,7 @@ http://spark.apache.org/ + repl /usr/share/spark root diff --git a/sbt/sbt b/sbt/sbt index 9de265bd07dcb..1b1aa1483a829 100755 --- a/sbt/sbt +++ b/sbt/sbt @@ -72,6 +72,7 @@ Usage: $script_name [options] -J-X pass option -X directly to the java runtime (-J is stripped) -S-X add -X to sbt's scalacOptions (-J is stripped) + -PmavenProfiles Enable a maven profile for the build. In the case of duplicated or conflicting options, the order above shows precedence: JAVA_OPTS lowest, command line options highest. diff --git a/sbt/sbt-launch-lib.bash b/sbt/sbt-launch-lib.bash index 64e40a88206be..c91fecf024ad4 100755 --- a/sbt/sbt-launch-lib.bash +++ b/sbt/sbt-launch-lib.bash @@ -16,6 +16,7 @@ declare -a residual_args declare -a java_args declare -a scalac_args declare -a sbt_commands +declare -a maven_profiles if test -x "$JAVA_HOME/bin/java"; then echo -e "Using $JAVA_HOME as default JAVA_HOME." @@ -87,6 +88,13 @@ addJava () { dlog "[addJava] arg = '$1'" java_args=( "${java_args[@]}" "$1" ) } + +enableProfile () { + dlog "[enableProfile] arg = '$1'" + maven_profiles=( "${maven_profiles[@]}" "$1" ) + export SBT_MAVEN_PROFILES="${maven_profiles[@]}" +} + addSbt () { dlog "[addSbt] arg = '$1'" sbt_commands=( "${sbt_commands[@]}" "$1" ) @@ -141,7 +149,8 @@ process_args () { -java-home) require_arg path "$1" "$2" && java_cmd="$2/bin/java" && export JAVA_HOME=$2 && shift 2 ;; -D*) addJava "$1" && shift ;; - -J*) addJava "${1:2}" && shift ;; + -J*) addJava "${1:2}" && shift ;; + -P*) enableProfile "$1" && shift ;; *) addResidual "$1" && shift ;; esac done diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 01d7b569080ea..6decde3fcd62d 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -31,6 +31,9 @@ jar Spark Project Catalyst http://spark.apache.org/ + + catalyst + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index e5653c5b14ac1..a34b236c8ac6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -120,7 +120,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val WHERE = Keyword("WHERE") protected val INTERSECT = Keyword("INTERSECT") protected val EXCEPT = Keyword("EXCEPT") - + protected val SUBSTR = Keyword("SUBSTR") + protected val SUBSTRING = Keyword("SUBSTRING") // Use reflection to find the reserved words defined in this class. protected val reservedWords = @@ -316,6 +317,12 @@ class SqlParser extends StandardTokenParsers with PackratParsers { IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ { case c ~ "," ~ t ~ "," ~ f => If(c,t,f) } | + (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression <~ ")" ^^ { + case s ~ "," ~ p => Substring(s,p,Literal(Integer.MAX_VALUE)) + } | + (SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ { + case s ~ "," ~ p ~ "," ~ l => Substring(s,p,l) + } | ident ~ "(" ~ repsep(expression, ",") <~ ")" ^^ { case udfName ~ _ ~ exprs => UnresolvedFunction(udfName, exprs) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 0d05d9808b407..616f1e2ecb60f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -113,11 +113,12 @@ trait OverrideCatalog extends Catalog { alias: Option[String] = None): LogicalPlan = { val (dbName, tblName) = processDatabaseAndTableName(databaseName, tableName) val overriddenTable = overrides.get((dbName, tblName)) + val tableWithQualifers = overriddenTable.map(r => Subquery(tblName, r)) // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are // properly qualified with this alias. val withAlias = - overriddenTable.map(r => alias.map(a => Subquery(a, r)).getOrElse(r)) + tableWithQualifers.map(r => alias.map(a => Subquery(a, r)).getOrElse(r)) withAlias.getOrElse(super.lookupRelation(dbName, tblName, alias)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala index a6ce90854dcb4..22941edef2d46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan * of itself with globally unique expression ids. */ trait MultiInstanceRelation { - def newInstance: this.type + def newInstance(): this.type } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 347471cebdc7e..97fc3a3b14b88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.expressions import java.util.regex.Pattern -import org.apache.spark.sql.catalyst.types.DataType -import org.apache.spark.sql.catalyst.types.StringType -import org.apache.spark.sql.catalyst.types.BooleanType +import scala.collection.IndexedSeqOptimized +import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.types.{BinaryType, BooleanType, DataType, StringType} + trait StringRegexExpression { self: BinaryExpression => @@ -32,7 +33,7 @@ trait StringRegexExpression { def escape(v: String): String def matches(regex: Pattern, str: String): Boolean - def nullable: Boolean = true + def nullable: Boolean = left.nullable || right.nullable def dataType: DataType = BooleanType // try cache the pattern for Literal @@ -157,19 +158,13 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE override def toString() = s"Lower($child)" } -/** A base class for functions that compare two strings, returning a boolean. */ -abstract class StringComparison extends Expression { - self: Product => +/** A base trait for functions that compare two strings, returning a boolean. */ +trait StringComparison { + self: BinaryExpression => type EvaluatedType = Any - def left: Expression - def right: Expression - - override def references = children.flatMap(_.references).toSet - override def children = left :: right :: Nil - - override def nullable: Boolean = true + def nullable: Boolean = left.nullable || right.nullable override def dataType: DataType = BooleanType def compare(l: String, r: String): Boolean @@ -184,26 +179,102 @@ abstract class StringComparison extends Expression { } } + def symbol: String = nodeName + override def toString() = s"$nodeName($left, $right)" } /** * A function that returns true if the string `left` contains the string `right`. */ -case class Contains(left: Expression, right: Expression) extends StringComparison { +case class Contains(left: Expression, right: Expression) + extends BinaryExpression with StringComparison { override def compare(l: String, r: String) = l.contains(r) } /** * A function that returns true if the string `left` starts with the string `right`. */ -case class StartsWith(left: Expression, right: Expression) extends StringComparison { +case class StartsWith(left: Expression, right: Expression) + extends BinaryExpression with StringComparison { def compare(l: String, r: String) = l.startsWith(r) } /** * A function that returns true if the string `left` ends with the string `right`. */ -case class EndsWith(left: Expression, right: Expression) extends StringComparison { +case class EndsWith(left: Expression, right: Expression) + extends BinaryExpression with StringComparison { def compare(l: String, r: String) = l.endsWith(r) } + +/** + * A function that takes a substring of its first argument starting at a given position. + * Defined for String and Binary types. + */ +case class Substring(str: Expression, pos: Expression, len: Expression) extends Expression { + + type EvaluatedType = Any + + override def foldable = str.foldable && pos.foldable && len.foldable + + def nullable: Boolean = str.nullable || pos.nullable || len.nullable + def dataType: DataType = { + if (!resolved) { + throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved") + } + if (str.dataType == BinaryType) str.dataType else StringType + } + + def references = children.flatMap(_.references).toSet + + override def children = str :: pos :: len :: Nil + + @inline + def slice[T, C <: Any](str: C, startPos: Int, sliceLen: Int) + (implicit ev: (C=>IndexedSeqOptimized[T,_])): Any = { + val len = str.length + // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and + // negative indices for start positions. If a start index i is greater than 0, it + // refers to element i-1 in the sequence. If a start index i is less than 0, it refers + // to the -ith element before the end of the sequence. If a start index i is 0, it + // refers to the first element. + + val start = startPos match { + case pos if pos > 0 => pos - 1 + case neg if neg < 0 => len + neg + case _ => 0 + } + + val end = sliceLen match { + case max if max == Integer.MAX_VALUE => max + case x => start + x + } + + str.slice(start, end) + } + + override def eval(input: Row): Any = { + val string = str.eval(input) + + val po = pos.eval(input) + val ln = len.eval(input) + + if ((string == null) || (po == null) || (ln == null)) { + null + } else { + val start = po.asInstanceOf[Int] + val length = ln.asInstanceOf[Int] + + string match { + case ba: Array[Byte] => slice(ba, start, length) + case other => slice(other.toString, start, length) + } + } + } + + override def toString = len match { + case max if max == Integer.MAX_VALUE => s"SUBSTR($str, $pos)" + case _ => s"SUBSTR($str, $pos, $len)" + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f0904f59d028f..c65987b7120b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -123,6 +123,7 @@ object LikeSimplification extends Rule[LogicalPlan] { val startsWith = "([^_%]+)%".r val endsWith = "%([^_%]+)".r val contains = "%([^_%]+)%".r + val equalTo = "([^_%]*)".r def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case Like(l, Literal(startsWith(pattern), StringType)) if !pattern.endsWith("\\") => @@ -131,6 +132,8 @@ object LikeSimplification extends Rule[LogicalPlan] { EndsWith(l, Literal(pattern)) case Like(l, Literal(contains(pattern), StringType)) if !pattern.endsWith("\\") => Contains(l, Literal(pattern)) + case Like(l, Literal(equalTo(pattern), StringType)) => + EqualTo(l, Literal(pattern)) } } @@ -150,11 +153,13 @@ object NullPropagation extends Rule[LogicalPlan] { case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType) case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType) case e @ GetField(Literal(null, _), _) => Literal(null, e.dataType) - case e @ Coalesce(children) => { - val newChildren = children.filter(c => c match { + + // For Coalesce, remove null literals. + case e @ Coalesce(children) => + val newChildren = children.filter { case Literal(null, _) => false case _ => true - }) + } if (newChildren.length == 0) { Literal(null, e.dataType) } else if (newChildren.length == 1) { @@ -162,12 +167,11 @@ object NullPropagation extends Rule[LogicalPlan] { } else { Coalesce(newChildren) } - } - case e @ If(Literal(v, _), trueValue, falseValue) => if (v == true) trueValue else falseValue - case e @ In(Literal(v, _), list) if (list.exists(c => c match { - case Literal(candidate, _) if candidate == v => true - case _ => false - })) => Literal(true, BooleanType) + + case e @ Substring(Literal(null, _), _, _) => Literal(null, e.dataType) + case e @ Substring(_, Literal(null, _), _) => Literal(null, e.dataType) + case e @ Substring(_, _, Literal(null, _)) => Literal(null, e.dataType) + // Put exceptional cases above if any case e: BinaryArithmetic => e.children match { case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) @@ -184,6 +188,11 @@ object NullPropagation extends Rule[LogicalPlan] { case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) case _ => e } + case e: StringComparison => e.children match { + case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) + case left :: Literal(null, _) :: Nil => Literal(null, e.dataType) + case _ => e + } } } } @@ -195,9 +204,19 @@ object NullPropagation extends Rule[LogicalPlan] { object ConstantFolding extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - // Skip redundant folding of literals. + // Skip redundant folding of literals. This rule is technically not necessary. Placing this + // here avoids running the next rule for Literal values, which would create a new Literal + // object and running eval unnecessarily. case l: Literal => l + + // Fold expressions that are foldable. case e if e.foldable => Literal(e.eval(null), e.dataType) + + // Fold "literal in (item1, item2, ..., literal, ...)" into true directly. + case In(Literal(v, _), list) if list.exists { + case Literal(candidate, _) if candidate == v => true + case _ => false + } => Literal(true, BooleanType) } } } @@ -227,6 +246,9 @@ object BooleanSimplification extends Rule[LogicalPlan] { case (l, Literal(false, BooleanType)) => l case (_, _) => or } + + // Turn "if (true) a else b" into "a", and if (false) a else b" into "b". + case e @ If(Literal(v, _), trueValue, falseValue) => if (v == true) trueValue else falseValue } } } @@ -248,12 +270,12 @@ object CombineFilters extends Rule[LogicalPlan] { */ object SimplifyFilters extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Filter(Literal(true, BooleanType), child) => - child - case Filter(Literal(null, _), child) => - LocalRelation(child.output) - case Filter(Literal(false, BooleanType), child) => - LocalRelation(child.output) + // If the filter condition always evaluate to true, remove the filter. + case Filter(Literal(true, BooleanType), child) => child + // If the filter condition always evaluate to null or false, + // replace the input with an empty relation. + case Filter(Literal(null, _), child) => LocalRelation(child.output, data = Seq.empty) + case Filter(Literal(false, BooleanType), child) => LocalRelation(child.output, data = Seq.empty) } } @@ -295,7 +317,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { /** * Splits join condition expressions into three categories based on the attributes required * to evaluate them. - * @returns (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth) + * @return (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth) */ private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = { val (leftEvaluateCondition, rest) = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index bb77bccf86176..cd4b5e9c1b529 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -19,17 +19,20 @@ package org.apache.spark.sql.catalyst.types import java.sql.Timestamp -import scala.util.parsing.combinator.RegexParsers - import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror} +import scala.util.parsing.combinator.RegexParsers import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.util.Utils /** - * + * A JVM-global lock that should be used to prevent thread safety issues when using things in + * scala.reflect.*. Note that Scala Reflection API is made thread-safe in 2.11, but not yet for + * 2.10.* builds. See SI-6240 for more details. */ +protected[catalyst] object ScalaReflectionLock + object DataType extends RegexParsers { protected lazy val primitiveType: Parser[DataType] = "StringType" ^^^ StringType | @@ -62,7 +65,6 @@ object DataType extends RegexParsers { "true" ^^^ true | "false" ^^^ false - protected lazy val structType: Parser[DataType] = "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { case fields => new StructType(fields) @@ -106,7 +108,7 @@ abstract class NativeType extends DataType { @transient val tag: TypeTag[JvmType] val ordering: Ordering[JvmType] - @transient val classTag = { + @transient val classTag = ScalaReflectionLock.synchronized { val mirror = runtimeMirror(Utils.getSparkClassLoader) ClassTag[JvmType](mirror.runtimeClass(tag.tpe)) } @@ -114,22 +116,24 @@ abstract class NativeType extends DataType { case object StringType extends NativeType with PrimitiveType { type JvmType = String - @transient lazy val tag = typeTag[JvmType] + @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } val ordering = implicitly[Ordering[JvmType]] } + case object BinaryType extends DataType with PrimitiveType { type JvmType = Array[Byte] } + case object BooleanType extends NativeType with PrimitiveType { type JvmType = Boolean - @transient lazy val tag = typeTag[JvmType] + @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } val ordering = implicitly[Ordering[JvmType]] } case object TimestampType extends NativeType { type JvmType = Timestamp - @transient lazy val tag = typeTag[JvmType] + @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } val ordering = new Ordering[JvmType] { def compare(x: Timestamp, y: Timestamp) = x.compareTo(y) @@ -159,7 +163,7 @@ abstract class IntegralType extends NumericType { case object LongType extends IntegralType { type JvmType = Long - @transient lazy val tag = typeTag[JvmType] + @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } val numeric = implicitly[Numeric[Long]] val integral = implicitly[Integral[Long]] val ordering = implicitly[Ordering[JvmType]] @@ -167,7 +171,7 @@ case object LongType extends IntegralType { case object IntegerType extends IntegralType { type JvmType = Int - @transient lazy val tag = typeTag[JvmType] + @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } val numeric = implicitly[Numeric[Int]] val integral = implicitly[Integral[Int]] val ordering = implicitly[Ordering[JvmType]] @@ -175,7 +179,7 @@ case object IntegerType extends IntegralType { case object ShortType extends IntegralType { type JvmType = Short - @transient lazy val tag = typeTag[JvmType] + @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } val numeric = implicitly[Numeric[Short]] val integral = implicitly[Integral[Short]] val ordering = implicitly[Ordering[JvmType]] @@ -183,7 +187,7 @@ case object ShortType extends IntegralType { case object ByteType extends IntegralType { type JvmType = Byte - @transient lazy val tag = typeTag[JvmType] + @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } val numeric = implicitly[Numeric[Byte]] val integral = implicitly[Integral[Byte]] val ordering = implicitly[Ordering[JvmType]] @@ -202,7 +206,7 @@ abstract class FractionalType extends NumericType { case object DecimalType extends FractionalType { type JvmType = BigDecimal - @transient lazy val tag = typeTag[JvmType] + @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } val numeric = implicitly[Numeric[BigDecimal]] val fractional = implicitly[Fractional[BigDecimal]] val ordering = implicitly[Ordering[JvmType]] @@ -210,7 +214,7 @@ case object DecimalType extends FractionalType { case object DoubleType extends FractionalType { type JvmType = Double - @transient lazy val tag = typeTag[JvmType] + @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } val numeric = implicitly[Numeric[Double]] val fractional = implicitly[Fractional[Double]] val ordering = implicitly[Ordering[JvmType]] @@ -218,7 +222,7 @@ case object DoubleType extends FractionalType { case object FloatType extends FractionalType { type JvmType = Float - @transient lazy val tag = typeTag[JvmType] + @transient lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } val numeric = implicitly[Numeric[Float]] val fractional = implicitly[Fractional[Float]] val ordering = implicitly[Ordering[JvmType]] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 84d72814778ba..73f546455b67f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -466,5 +466,81 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(c1 === c2, false, row) checkEvaluation(c1 !== c2, true, row) } -} + test("StringComparison") { + val row = new GenericRow(Array[Any]("abc", null)) + val c1 = 'a.string.at(0) + val c2 = 'a.string.at(1) + + checkEvaluation(Contains(c1, "b"), true, row) + checkEvaluation(Contains(c1, "x"), false, row) + checkEvaluation(Contains(c2, "b"), null, row) + checkEvaluation(Contains(c1, Literal(null, StringType)), null, row) + + checkEvaluation(StartsWith(c1, "a"), true, row) + checkEvaluation(StartsWith(c1, "b"), false, row) + checkEvaluation(StartsWith(c2, "a"), null, row) + checkEvaluation(StartsWith(c1, Literal(null, StringType)), null, row) + + checkEvaluation(EndsWith(c1, "c"), true, row) + checkEvaluation(EndsWith(c1, "b"), false, row) + checkEvaluation(EndsWith(c2, "b"), null, row) + checkEvaluation(EndsWith(c1, Literal(null, StringType)), null, row) + } + + test("Substring") { + val row = new GenericRow(Array[Any]("example", "example".toArray.map(_.toByte))) + + val s = 'a.string.at(0) + + // substring from zero position with less-than-full length + checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(2, IntegerType)), "ex", row) + checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(2, IntegerType)), "ex", row) + + // substring from zero position with full length + checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(7, IntegerType)), "example", row) + checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(7, IntegerType)), "example", row) + + // substring from zero position with greater-than-full length + checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(100, IntegerType)), "example", row) + checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(100, IntegerType)), "example", row) + + // substring from nonzero position with less-than-full length + checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(2, IntegerType)), "xa", row) + + // substring from nonzero position with full length + checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(6, IntegerType)), "xample", row) + + // substring from nonzero position with greater-than-full length + checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(100, IntegerType)), "xample", row) + + // zero-length substring (within string bounds) + checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(0, IntegerType)), "", row) + + // zero-length substring (beyond string bounds) + checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(4, IntegerType)), "", row) + + // substring(null, _, _) -> null + checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(4, IntegerType)), null, new GenericRow(Array[Any](null))) + + // substring(_, null, _) -> null + checkEvaluation(Substring(s, Literal(null, IntegerType), Literal(4, IntegerType)), null, row) + + // substring(_, _, null) -> null + checkEvaluation(Substring(s, Literal(100, IntegerType), Literal(null, IntegerType)), null, row) + + // 2-arg substring from zero position + checkEvaluation(Substring(s, Literal(0, IntegerType), Literal(Integer.MAX_VALUE, IntegerType)), "example", row) + checkEvaluation(Substring(s, Literal(1, IntegerType), Literal(Integer.MAX_VALUE, IntegerType)), "example", row) + + // 2-arg substring from nonzero position + checkEvaluation(Substring(s, Literal(2, IntegerType), Literal(Integer.MAX_VALUE, IntegerType)), "xample", row) + + val s_notNull = 'a.string.notNull.at(0) + + assert(Substring(s, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === true) + assert(Substring(s_notNull, Literal(0, IntegerType), Literal(2, IntegerType)).nullable === false) + assert(Substring(s_notNull, Literal(null, IntegerType), Literal(2, IntegerType)).nullable === true) + assert(Substring(s_notNull, Literal(0, IntegerType), Literal(null, IntegerType)).nullable === true) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 0ff82064012a8..d607eed1bea89 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -201,7 +201,15 @@ class ConstantFoldingSuite extends PlanTest { Like(Literal(null, StringType), "abc") as 'c13, Like("abc", Literal(null, StringType)) as 'c14, - Upper(Literal(null, StringType)) as 'c15) + Upper(Literal(null, StringType)) as 'c15, + + Substring(Literal(null, StringType), 0, 1) as 'c16, + Substring("abc", Literal(null, IntegerType), 1) as 'c17, + Substring("abc", 0, Literal(null, IntegerType)) as 'c18, + + Contains(Literal(null, StringType), "abc") as 'c19, + Contains("abc", Literal(null, StringType)) as 'c20 + ) val optimized = Optimize(originalQuery.analyze) @@ -228,8 +236,15 @@ class ConstantFoldingSuite extends PlanTest { Literal(null, BooleanType) as 'c13, Literal(null, BooleanType) as 'c14, - Literal(null, StringType) as 'c15) - .analyze + Literal(null, StringType) as 'c15, + + Literal(null, StringType) as 'c16, + Literal(null, StringType) as 'c17, + Literal(null, StringType) as 'c18, + + Literal(null, BooleanType) as 'c19, + Literal(null, BooleanType) as 'c20 + ).analyze comparePlans(optimized, correctAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala new file mode 100644 index 0000000000000..b10577c8001e2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.rules._ + +/* Implicit conversions */ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ + +class LikeSimplificationSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Like Simplification", Once, + LikeSimplification) :: Nil + } + + val testRelation = LocalRelation('a.string) + + test("simplify Like into StartsWith") { + val originalQuery = + testRelation + .where(('a like "abc%") || ('a like "abc\\%")) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = testRelation + .where(StartsWith('a, "abc") || ('a like "abc\\%")) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("simplify Like into EndsWith") { + val originalQuery = + testRelation + .where('a like "%xyz") + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = testRelation + .where(EndsWith('a, "xyz")) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("simplify Like into Contains") { + val originalQuery = + testRelation + .where(('a like "%mn%") || ('a like "%mn\\%")) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = testRelation + .where(Contains('a, "mn") || ('a like "%mn\\%")) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("simplify Like into EqualTo") { + val originalQuery = + testRelation + .where(('a like "") || ('a like "abc")) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = testRelation + .where(('a === "") || ('a === "abc")) + .analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 8210fd1f210d1..c309c43804d97 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -31,6 +31,9 @@ jar Spark Project SQL http://spark.apache.org/ + + sql + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index b6fb46a3acc03..2b787e14f3f15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -52,7 +52,7 @@ trait SQLConf { /** ********************** SQLConf functionality methods ************ */ @transient - protected[sql] val settings = java.util.Collections.synchronizedMap( + private val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) def set(props: Properties): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 8bcfc7c064c2f..993d085c75089 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -256,6 +256,26 @@ class SchemaRDD( def unionAll(otherPlan: SchemaRDD) = new SchemaRDD(sqlContext, Union(logicalPlan, otherPlan.logicalPlan)) + /** + * Performs a relational except on two SchemaRDDs + * + * @param otherPlan the [[SchemaRDD]] that should be excepted from this one. + * + * @group Query + */ + def except(otherPlan: SchemaRDD): SchemaRDD = + new SchemaRDD(sqlContext, Except(logicalPlan, otherPlan.logicalPlan)) + + /** + * Performs a relational intersect on two SchemaRDDs + * + * @param otherPlan the [[SchemaRDD]] that should be intersected with this one. + * + * @group Query + */ + def intersect(otherPlan: SchemaRDD): SchemaRDD = + new SchemaRDD(sqlContext, Intersect(logicalPlan, otherPlan.logicalPlan)) + /** * Filters tuples using a function over the value of the specified column. * @@ -360,32 +380,32 @@ class SchemaRDD( val fields = structType.fields.map(field => (field.name, field.dataType)) val map: JMap[String, Any] = new java.util.HashMap row.zip(fields).foreach { - case (obj, (name, dataType)) => + case (obj, (attrName, dataType)) => dataType match { - case struct: StructType => map.put(name, rowToMap(obj.asInstanceOf[Row], struct)) + case struct: StructType => map.put(attrName, rowToMap(obj.asInstanceOf[Row], struct)) case array @ ArrayType(struct: StructType) => val arrayValues = obj match { case seq: Seq[Any] => seq.map(element => rowToMap(element.asInstanceOf[Row], struct)).asJava - case list: JList[Any] => + case list: JList[_] => list.map(element => rowToMap(element.asInstanceOf[Row], struct)) - case set: JSet[Any] => + case set: JSet[_] => set.map(element => rowToMap(element.asInstanceOf[Row], struct)) - case array if array != null && array.getClass.isArray => - array.asInstanceOf[Array[Any]].map { + case arr if arr != null && arr.getClass.isArray => + arr.asInstanceOf[Array[Any]].map { element => rowToMap(element.asInstanceOf[Row], struct) } case other => other } - map.put(name, arrayValues) + map.put(attrName, arrayValues) case array: ArrayType => { val arrayValues = obj match { case seq: Seq[Any] => seq.asJava case other => other } - map.put(name, arrayValues) + map.put(attrName, arrayValues) } - case other => map.put(name, obj) + case other => map.put(attrName, obj) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index aff6ffe9f3478..8fbf13b8b0150 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.api.java +import java.util.{List => JList} + import org.apache.spark.Partitioner import org.apache.spark.api.java.{JavaRDDLike, JavaRDD} import org.apache.spark.api.java.function.{Function => JFunction} @@ -96,6 +98,20 @@ class JavaSchemaRDD( this } + // Overridden actions from JavaRDDLike. + + override def collect(): JList[Row] = { + import scala.collection.JavaConversions._ + val arr: java.util.Collection[Row] = baseSchemaRDD.collect().toSeq.map(new Row(_)) + new java.util.ArrayList(arr) + } + + override def take(num: Int): JList[Row] = { + import scala.collection.JavaConversions._ + val arr: java.util.Collection[Row] = baseSchemaRDD.take(num).toSeq.map(new Row(_)) + new java.util.ArrayList(arr) + } + // Transformations (return a new RDD) /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index e1e4f24c6c66c..88901debbb4e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.columnar +import java.nio.ByteBuffer + +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Attribute} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -26,22 +29,19 @@ import org.apache.spark.SparkConf object InMemoryRelation { def apply(useCompression: Boolean, child: SparkPlan): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, child) + new InMemoryRelation(child.output, useCompression, child)() } private[sql] case class InMemoryRelation( output: Seq[Attribute], useCompression: Boolean, child: SparkPlan) + (private var _cachedColumnBuffers: RDD[Array[ByteBuffer]] = null) extends LogicalPlan with MultiInstanceRelation { - override def children = Seq.empty - override def references = Set.empty - - override def newInstance() = - new InMemoryRelation(output.map(_.newInstance), useCompression, child).asInstanceOf[this.type] - - lazy val cachedColumnBuffers = { + // If the cached column buffers were not passed in, we calculate them in the constructor. + // As in Spark, the actual work of caching is lazy. + if (_cachedColumnBuffers == null) { val output = child.output val cached = child.execute().mapPartitions { iterator => val columnBuilders = output.map { attribute => @@ -62,10 +62,23 @@ private[sql] case class InMemoryRelation( }.cache() cached.setName(child.toString) - // Force the materialization of the cached RDD. - cached.count() - cached + _cachedColumnBuffers = cached + } + + + override def children = Seq.empty + + override def references = Set.empty + + override def newInstance() = { + new InMemoryRelation( + output.map(_.newInstance), + useCompression, + child)( + _cachedColumnBuffers).asInstanceOf[this.type] } + + def cachedColumnBuffers = _cachedColumnBuffers } private[sql] case class InMemoryColumnarTableScan( @@ -83,7 +96,11 @@ private[sql] case class InMemoryColumnarTableScan( new Iterator[Row] { // Find the ordinals of the requested columns. If none are requested, use the first. val requestedColumns = - if (attributes.isEmpty) Seq(0) else attributes.map(relation.output.indexOf(_)) + if (attributes.isEmpty) { + Seq(0) + } else { + attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId)) + } val columnAccessors = requestedColumns.map(columnBuffers(_)).map(ColumnAccessor(_)) val nextRow = new GenericMutableRow(columnAccessors.length) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7080074a69c07..c078e71fe0290 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -247,8 +247,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Distinct(child) => - execution.Aggregate( - partial = false, child.output, child.output, planLater(child))(sqlContext) :: Nil + execution.Distinct(partial = false, + execution.Distinct(partial = true, planLater(child))) :: Nil case logical.Sort(sortExprs, child) => // This sort is a global sort. Its requiredDistribution will be an OrderedDistribution. execution.Sort(sortExprs, global = true, planLater(child)):: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 97abd636ab5fb..966d8f95fc83c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{OrderedDistribution, UnspecifiedDistribution} +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, OrderedDistribution, UnspecifiedDistribution} import org.apache.spark.util.MutablePair /** @@ -248,6 +248,37 @@ object ExistingRdd { case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { override def execute() = rdd } +/** + * :: DeveloperApi :: + * Computes the set of distinct input rows using a HashSet. + * @param partial when true the distinct operation is performed partially, per partition, without + * shuffling the data. + * @param child the input query plan. + */ +@DeveloperApi +case class Distinct(partial: Boolean, child: SparkPlan) extends UnaryNode { + override def output = child.output + + override def requiredChildDistribution = + if (partial) UnspecifiedDistribution :: Nil else ClusteredDistribution(child.output) :: Nil + + override def execute() = { + child.execute().mapPartitions { iter => + val hashSet = new scala.collection.mutable.HashSet[Row]() + + var currentRow: Row = null + while (iter.hasNext) { + currentRow = iter.next() + if (!hashSet.contains(currentRow)) { + hashSet.add(currentRow.copy()) + } + } + + hashSet.iterator + } + } +} + /** * :: DeveloperApi :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index f6cbca96483e2..df80dfb98b93c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -204,14 +204,14 @@ private[sql] object JsonRDD extends Logging { case (key, value) => (s"`$key`", value) }.toSet keyValuePairs.flatMap { - case (key: String, struct: Map[String, Any]) => { - // The value associted with the key is an JSON object. - allKeysWithValueTypes(struct).map { + case (key: String, struct: Map[_, _]) => { + // The value associated with the key is an JSON object. + allKeysWithValueTypes(struct.asInstanceOf[Map[String, Any]]).map { case (k, dataType) => (s"$key.$k", dataType) } ++ Set((key, StructType(Nil))) } - case (key: String, array: List[Any]) => { - // The value associted with the key is an array. + case (key: String, array: List[_]) => { + // The value associated with the key is an array. typeOfArray(array) match { case ArrayType(StructType(Nil)) => { // The elements of this arrays are structs. @@ -235,12 +235,12 @@ private[sql] object JsonRDD extends Logging { * the parsing very slow. */ private def scalafy(obj: Any): Any = obj match { - case map: java.util.Map[String, Object] => + case map: java.util.Map[_, _] => // .map(identity) is used as a workaround of non-serializable Map // generated by .mapValues. // This issue is documented at https://issues.scala-lang.org/browse/SI-7005 map.toMap.mapValues(scalafy).map(identity) - case list: java.util.List[Object] => + case list: java.util.List[_] => list.toList.map(scalafy) case atom => atom } @@ -320,8 +320,8 @@ private[sql] object JsonRDD extends Logging { private def toString(value: Any): String = { value match { - case value: Map[String, Any] => toJsonObjectString(value) - case value: Seq[Any] => toJsonArrayString(value) + case value: Map[_, _] => toJsonObjectString(value.asInstanceOf[Map[String, Any]]) + case value: Seq[_] => toJsonArrayString(value) case value => Option(value).map(_.toString).orNull } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 889a408e3c393..de8fe2dae38f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -114,7 +114,7 @@ private[sql] object CatalystConverter { } } // All other primitive types use the default converter - case ctype: NativeType => { // note: need the type tag here! + case ctype: PrimitiveType => { // note: need the type tag here! new CatalystPrimitiveConverter(parent, fieldIndex) } case _ => throw new RuntimeException( @@ -229,9 +229,9 @@ private[parquet] class CatalystGroupConverter( this(attributes.map(a => new FieldType(a.name, a.dataType, a.nullable)), 0, null) protected [parquet] val converters: Array[Converter] = - schema.map(field => - CatalystConverter.createConverter(field, schema.indexOf(field), this)) - .toArray + schema.zipWithIndex.map { + case (field, idx) => CatalystConverter.createConverter(field, idx, this) + }.toArray override val size = schema.size @@ -288,9 +288,9 @@ private[parquet] class CatalystPrimitiveRowConverter( new ParquetRelation.RowType(attributes.length)) protected [parquet] val converters: Array[Converter] = - schema.map(field => - CatalystConverter.createConverter(field, schema.indexOf(field), this)) - .toArray + schema.zipWithIndex.map { + case (field, idx) => CatalystConverter.createConverter(field, idx, this) + }.toArray override val size = schema.size diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index ade823b51c9cd..ea74320d06c86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -17,27 +17,34 @@ package org.apache.spark.sql.parquet +import scala.collection.JavaConversions._ +import scala.collection.mutable +import scala.util.Try + import java.io.IOException +import java.lang.{Long => JLong} import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, List => JList} import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} -import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat, FileOutputCommitter} +import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter -import parquet.hadoop.{ParquetRecordReader, ParquetInputFormat, ParquetOutputFormat} -import parquet.hadoop.api.ReadSupport +import parquet.hadoop._ +import parquet.hadoop.api.{InitContext, ReadSupport} +import parquet.hadoop.metadata.GlobalMetaData import parquet.hadoop.util.ContextUtil -import parquet.io.InvalidRecordException +import parquet.io.ParquetDecodingException import parquet.schema.MessageType -import org.apache.spark.{Logging, SerializableWritable, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} +import org.apache.spark.{Logging, SerializableWritable, TaskContext} /** * Parquet table scan operator. Imports the file that backs the given @@ -55,16 +62,14 @@ case class ParquetTableScan( override def execute(): RDD[Row] = { val sc = sqlContext.sparkContext val job = new Job(sc.hadoopConfiguration) - ParquetInputFormat.setReadSupportClass( - job, - classOf[org.apache.spark.sql.parquet.RowReadSupport]) + ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) + val conf: Configuration = ContextUtil.getConfiguration(job) - val fileList = FileSystemHelper.listFiles(relation.path, conf) - // add all paths in the directory but skip "hidden" ones such - // as "_SUCCESS" and "_metadata" - for (path <- fileList if !path.getName.startsWith("_")) { - NewFileInputFormat.addInputPath(job, path) + val qualifiedPath = { + val path = new Path(relation.path) + path.getFileSystem(conf).makeQualified(path) } + NewFileInputFormat.addInputPath(job, qualifiedPath) // Store both requested and original schema in `Configuration` conf.set( @@ -87,7 +92,7 @@ case class ParquetTableScan( sc.newAPIHadoopRDD( conf, - classOf[org.apache.spark.sql.parquet.FilteringParquetRowInputFormat], + classOf[FilteringParquetRowInputFormat], classOf[Void], classOf[Row]) .map(_._2) @@ -122,14 +127,7 @@ case class ParquetTableScan( private def validateProjection(projection: Seq[Attribute]): Boolean = { val original: MessageType = relation.parquetSchema val candidate: MessageType = ParquetTypesConverter.convertFromAttributes(projection) - try { - original.checkContains(candidate) - true - } catch { - case e: InvalidRecordException => { - false - } - } + Try(original.checkContains(candidate)).isSuccess } } @@ -302,6 +300,11 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) */ private[parquet] class FilteringParquetRowInputFormat extends parquet.hadoop.ParquetInputFormat[Row] with Logging { + + private var footers: JList[Footer] = _ + + private var fileStatuses= Map.empty[Path, FileStatus] + override def createRecordReader( inputSplit: InputSplit, taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = { @@ -318,6 +321,70 @@ private[parquet] class FilteringParquetRowInputFormat new ParquetRecordReader[Row](readSupport) } } + + override def getFooters(jobContext: JobContext): JList[Footer] = { + if (footers eq null) { + val statuses = listStatus(jobContext) + fileStatuses = statuses.map(file => file.getPath -> file).toMap + footers = getFooters(ContextUtil.getConfiguration(jobContext), statuses) + } + + footers + } + + // TODO Remove this method and related code once PARQUET-16 is fixed + // This method together with the `getFooters` method and the `fileStatuses` field are just used + // to mimic this PR: https://github.com/apache/incubator-parquet-mr/pull/17 + override def getSplits( + configuration: Configuration, + footers: JList[Footer]): JList[ParquetInputSplit] = { + + val maxSplitSize: JLong = configuration.getLong("mapred.max.split.size", Long.MaxValue) + val minSplitSize: JLong = + Math.max(getFormatMinSplitSize(), configuration.getLong("mapred.min.split.size", 0L)) + if (maxSplitSize < 0 || minSplitSize < 0) { + throw new ParquetDecodingException( + s"maxSplitSize or minSplitSie should not be negative: maxSplitSize = $maxSplitSize;" + + s" minSplitSize = $minSplitSize") + } + + val getGlobalMetaData = + classOf[ParquetFileWriter].getDeclaredMethod("getGlobalMetaData", classOf[JList[Footer]]) + getGlobalMetaData.setAccessible(true) + val globalMetaData = getGlobalMetaData.invoke(null, footers).asInstanceOf[GlobalMetaData] + + val readContext = getReadSupport(configuration).init( + new InitContext(configuration, + globalMetaData.getKeyValueMetaData(), + globalMetaData.getSchema())) + + val generateSplits = + classOf[ParquetInputFormat[_]].getDeclaredMethods.find(_.getName == "generateSplits").get + generateSplits.setAccessible(true) + + val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] + for (footer <- footers) { + val fs = footer.getFile.getFileSystem(configuration) + val file = footer.getFile + val fileStatus = fileStatuses.getOrElse(file, fs.getFileStatus(file)) + val parquetMetaData = footer.getParquetMetadata + val blocks = parquetMetaData.getBlocks + val fileBlockLocations = fs.getFileBlockLocations(fileStatus, 0, fileStatus.getLen) + splits.addAll( + generateSplits.invoke( + null, + blocks, + fileBlockLocations, + fileStatus, + parquetMetaData.getFileMetaData, + readContext.getRequestedSchema.toString, + readContext.getReadSupportMetadata, + minSplitSize, + maxSplitSize).asInstanceOf[JList[ParquetInputSplit]]) + } + + splits + } } private[parquet] object FileSystemHelper { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 9cd5dc5bbd393..39294a3f4bf5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -17,20 +17,19 @@ package org.apache.spark.sql.parquet -import org.apache.hadoop.conf.Configuration +import java.util.{HashMap => JHashMap} +import org.apache.hadoop.conf.Configuration import parquet.column.ParquetProperties import parquet.hadoop.ParquetOutputFormat import parquet.hadoop.api.ReadSupport.ReadContext import parquet.hadoop.api.{ReadSupport, WriteSupport} import parquet.io.api._ -import parquet.schema.{MessageType, MessageTypeParser} +import parquet.schema.MessageType import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.execution.SparkSqlSerializer -import com.google.common.io.BaseEncoding /** * A `parquet.io.api.RecordMaterializer` for Rows. @@ -93,8 +92,8 @@ private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging { configuration: Configuration, keyValueMetaData: java.util.Map[String, String], fileSchema: MessageType): ReadContext = { - var parquetSchema: MessageType = fileSchema - var metadata: java.util.Map[String, String] = new java.util.HashMap[String, String]() + var parquetSchema = fileSchema + val metadata = new JHashMap[String, String]() val requestedAttributes = RowReadSupport.getRequestedSchema(configuration) if (requestedAttributes != null) { @@ -109,7 +108,7 @@ private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging { metadata.put(RowReadSupport.SPARK_METADATA_KEY, origAttributesStr) } - return new ReadSupport.ReadContext(parquetSchema, metadata) + new ReadSupport.ReadContext(parquetSchema, metadata) } } @@ -132,13 +131,17 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { private[parquet] var attributes: Seq[Attribute] = null override def init(configuration: Configuration): WriteSupport.WriteContext = { - attributes = if (attributes == null) RowWriteSupport.getSchema(configuration) else attributes - + val origAttributesStr: String = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) + val metadata = new JHashMap[String, String]() + metadata.put(RowReadSupport.SPARK_METADATA_KEY, origAttributesStr) + + if (attributes == null) { + attributes = ParquetTypesConverter.convertFromString(origAttributesStr) + } + log.debug(s"write support initialized for requested schema $attributes") ParquetRelation.enableLogForwarding() - new WriteSupport.WriteContext( - ParquetTypesConverter.convertFromAttributes(attributes), - new java.util.HashMap[java.lang.String, java.lang.String]()) + new WriteSupport.WriteContext(ParquetTypesConverter.convertFromAttributes(attributes), metadata) } override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { @@ -156,7 +159,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { writer.startMessage() while(index < attributes.size) { // null values indicate optional fields but we do not check currently - if (record(index) != null && record(index) != Nil) { + if (record(index) != null) { writer.startField(attributes(index).name, index) writeValue(attributes(index).dataType, record(index)) writer.endField(attributes(index).name, index) @@ -167,7 +170,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { } private[parquet] def writeValue(schema: DataType, value: Any): Unit = { - if (value != null && value != Nil) { + if (value != null) { schema match { case t @ ArrayType(_) => writeArray( t, @@ -184,13 +187,15 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { } private[parquet] def writePrimitive(schema: PrimitiveType, value: Any): Unit = { - if (value != null && value != Nil) { + if (value != null) { schema match { case StringType => writer.addBinary( Binary.fromByteArray( value.asInstanceOf[String].getBytes("utf-8") ) ) + case BinaryType => writer.addBinary( + Binary.fromByteArray(value.asInstanceOf[Array[Byte]])) case IntegerType => writer.addInteger(value.asInstanceOf[Int]) case ShortType => writer.addInteger(value.asInstanceOf[Short]) case LongType => writer.addLong(value.asInstanceOf[Long]) @@ -206,12 +211,12 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { private[parquet] def writeStruct( schema: StructType, struct: CatalystConverter.StructScalaType[_]): Unit = { - if (struct != null && struct != Nil) { + if (struct != null) { val fields = schema.fields.toArray writer.startGroup() var i = 0 while(i < fields.size) { - if (struct(i) != null && struct(i) != Nil) { + if (struct(i) != null) { writer.startField(fields(i).name, i) writeValue(fields(i).dataType, struct(i)) writer.endField(fields(i).name, i) @@ -299,6 +304,8 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { record(index).asInstanceOf[String].getBytes("utf-8") ) ) + case BinaryType => writer.addBinary( + Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]])) case IntegerType => writer.addInteger(record.getInt(index)) case ShortType => writer.addInteger(record.getShort(index)) case LongType => writer.addLong(record.getLong(index)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala index 1dc58633a2a68..d4599da711254 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala @@ -58,7 +58,7 @@ private[sql] object ParquetTestData { """message myrecord { optional boolean myboolean; optional int32 myint; - optional binary mystring; + optional binary mystring (UTF8); optional int64 mylong; optional float myfloat; optional double mydouble; @@ -87,7 +87,7 @@ private[sql] object ParquetTestData { message myrecord { required boolean myboolean; required int32 myint; - required binary mystring; + required binary mystring (UTF8); required int64 mylong; required float myfloat; required double mydouble; @@ -119,14 +119,14 @@ private[sql] object ParquetTestData { // so that array types can be translated correctly. """ message AddressBook { - required binary owner; + required binary owner (UTF8); optional group ownerPhoneNumbers { - repeated binary array; + repeated binary array (UTF8); } optional group contacts { repeated group array { - required binary name; - optional binary phoneNumber; + required binary name (UTF8); + optional binary phoneNumber (UTF8); } } } @@ -181,16 +181,16 @@ private[sql] object ParquetTestData { required int32 x; optional group data1 { repeated group map { - required binary key; + required binary key (UTF8); required int32 value; } } required group data2 { repeated group map { - required binary key; + required binary key (UTF8); required group value { required int64 payload1; - optional binary payload2; + optional binary payload2 (UTF8); } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index f9046368e7ced..58370b955a5ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -22,6 +22,7 @@ import java.io.IOException import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter import parquet.hadoop.{ParquetFileReader, Footer, ParquetFileWriter} import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData} @@ -42,20 +43,22 @@ private[parquet] object ParquetTypesConverter extends Logging { def isPrimitiveType(ctype: DataType): Boolean = classOf[PrimitiveType] isAssignableFrom ctype.getClass - def toPrimitiveDataType(parquetType : ParquetPrimitiveTypeName): DataType = parquetType match { - case ParquetPrimitiveTypeName.BINARY => StringType - case ParquetPrimitiveTypeName.BOOLEAN => BooleanType - case ParquetPrimitiveTypeName.DOUBLE => DoubleType - case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY => ArrayType(ByteType) - case ParquetPrimitiveTypeName.FLOAT => FloatType - case ParquetPrimitiveTypeName.INT32 => IntegerType - case ParquetPrimitiveTypeName.INT64 => LongType - case ParquetPrimitiveTypeName.INT96 => - // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? - sys.error("Potential loss of precision: cannot convert INT96") - case _ => sys.error( - s"Unsupported parquet datatype $parquetType") - } + def toPrimitiveDataType(parquetType: ParquetPrimitiveType): DataType = + parquetType.getPrimitiveTypeName match { + case ParquetPrimitiveTypeName.BINARY + if parquetType.getOriginalType == ParquetOriginalType.UTF8 => StringType + case ParquetPrimitiveTypeName.BINARY => BinaryType + case ParquetPrimitiveTypeName.BOOLEAN => BooleanType + case ParquetPrimitiveTypeName.DOUBLE => DoubleType + case ParquetPrimitiveTypeName.FLOAT => FloatType + case ParquetPrimitiveTypeName.INT32 => IntegerType + case ParquetPrimitiveTypeName.INT64 => LongType + case ParquetPrimitiveTypeName.INT96 => + // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? + sys.error("Potential loss of precision: cannot convert INT96") + case _ => sys.error( + s"Unsupported parquet datatype $parquetType") + } /** * Converts a given Parquet `Type` into the corresponding @@ -104,7 +107,7 @@ private[parquet] object ParquetTypesConverter extends Logging { } if (parquetType.isPrimitive) { - toPrimitiveDataType(parquetType.asPrimitiveType.getPrimitiveTypeName) + toPrimitiveDataType(parquetType.asPrimitiveType) } else { val groupType = parquetType.asGroupType() parquetType.getOriginalType match { @@ -164,18 +167,17 @@ private[parquet] object ParquetTypesConverter extends Logging { * @return The name of the corresponding Parquet primitive type */ def fromPrimitiveDataType(ctype: DataType): - Option[ParquetPrimitiveTypeName] = ctype match { - case StringType => Some(ParquetPrimitiveTypeName.BINARY) - case BooleanType => Some(ParquetPrimitiveTypeName.BOOLEAN) - case DoubleType => Some(ParquetPrimitiveTypeName.DOUBLE) - case ArrayType(ByteType) => - Some(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY) - case FloatType => Some(ParquetPrimitiveTypeName.FLOAT) - case IntegerType => Some(ParquetPrimitiveTypeName.INT32) + Option[(ParquetPrimitiveTypeName, Option[ParquetOriginalType])] = ctype match { + case StringType => Some(ParquetPrimitiveTypeName.BINARY, Some(ParquetOriginalType.UTF8)) + case BinaryType => Some(ParquetPrimitiveTypeName.BINARY, None) + case BooleanType => Some(ParquetPrimitiveTypeName.BOOLEAN, None) + case DoubleType => Some(ParquetPrimitiveTypeName.DOUBLE, None) + case FloatType => Some(ParquetPrimitiveTypeName.FLOAT, None) + case IntegerType => Some(ParquetPrimitiveTypeName.INT32, None) // There is no type for Byte or Short so we promote them to INT32. - case ShortType => Some(ParquetPrimitiveTypeName.INT32) - case ByteType => Some(ParquetPrimitiveTypeName.INT32) - case LongType => Some(ParquetPrimitiveTypeName.INT64) + case ShortType => Some(ParquetPrimitiveTypeName.INT32, None) + case ByteType => Some(ParquetPrimitiveTypeName.INT32, None) + case LongType => Some(ParquetPrimitiveTypeName.INT64, None) case _ => None } @@ -227,9 +229,10 @@ private[parquet] object ParquetTypesConverter extends Logging { if (nullable) Repetition.OPTIONAL else Repetition.REQUIRED } val primitiveType = fromPrimitiveDataType(ctype) - if (primitiveType.isDefined) { - new ParquetPrimitiveType(repetition, primitiveType.get, name) - } else { + primitiveType.map { + case (primitiveType, originalType) => + new ParquetPrimitiveType(repetition, primitiveType, name, originalType.orNull) + }.getOrElse { ctype match { case ArrayType(elementType) => { val parquetElementType = fromDataType( @@ -237,7 +240,7 @@ private[parquet] object ParquetTypesConverter extends Logging { CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, nullable = false, inArray = true) - ConversionPatterns.listType(repetition, name, parquetElementType) + ConversionPatterns.listType(repetition, name, parquetElementType) } case StructType(structFields) => { val fields = structFields.map { @@ -365,20 +368,24 @@ private[parquet] object ParquetTypesConverter extends Logging { s"Expected $path for be a directory with Parquet files/metadata") } ParquetRelation.enableLogForwarding() - val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE) - // if this is a new table that was just created we will find only the metadata file - if (fs.exists(metadataPath) && fs.isFile(metadataPath)) { - ParquetFileReader.readFooter(conf, metadataPath) - } else { - // there may be one or more Parquet files in the given directory - val footers = ParquetFileReader.readFooters(conf, fs.getFileStatus(path)) - // TODO: for now we assume that all footers (if there is more than one) have identical - // metadata; we may want to add a check here at some point - if (footers.size() == 0) { - throw new IllegalArgumentException(s"Could not find Parquet metadata at path $path") - } - footers(0).getParquetMetadata + + val children = fs.listStatus(path).filterNot { + _.getPath.getName == FileOutputCommitter.SUCCEEDED_FILE_NAME } + + // NOTE (lian): Parquet "_metadata" file can be very slow if the file consists of lots of row + // groups. Since Parquet schema is replicated among all row groups, we only need to touch a + // single row group to read schema related metadata. Notice that we are making assumptions that + // all data in a single Parquet file have the same schema, which is normally true. + children + // Try any non-"_metadata" file first... + .find(_.getPath.getName != ParquetFileWriter.PARQUET_METADATA_FILE) + // ... and fallback to "_metadata" if no such file exists (which implies the Parquet file is + // empty, thus normally the "_metadata" file is expected to be fairly small). + .orElse(children.find(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE)) + .map(ParquetFileReader.readFooter(conf, _)) + .getOrElse( + throw new IllegalArgumentException(s"Could not find Parquet metadata at path $path")) } /** @@ -401,7 +408,7 @@ private[parquet] object ParquetTypesConverter extends Logging { } else { val attributes = convertToAttributes( readMetaData(origPath, conf).getFileMetaData.getSchema) - log.warn(s"Falling back to schema conversion from Parquet types; result: $attributes") + log.info(s"Falling back to schema conversion from Parquet types; result: $attributes") attributes } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index 04ac008682f5f..68dae58728a2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -168,4 +168,25 @@ class DslQuerySuite extends QueryTest { test("zero count") { assert(emptyTableData.count() === 0) } + + test("except") { + checkAnswer( + lowerCaseData.except(upperCaseData), + (1, "a") :: + (2, "b") :: + (3, "c") :: + (4, "d") :: Nil) + checkAnswer(lowerCaseData.except(lowerCaseData), Nil) + checkAnswer(upperCaseData.except(upperCaseData), Nil) + } + + test("intersect") { + checkAnswer( + lowerCaseData.intersect(lowerCaseData), + (1, "a") :: + (2, "b") :: + (3, "c") :: + (4, "d") :: Nil) + checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 054b14f8f7ffa..e17ecc87fd52a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -36,36 +36,6 @@ class JoinSuite extends QueryTest { assert(planned.size === 1) } - test("plans broadcast hash join, given hints") { - - def mkTest(buildSide: BuildSide, leftTable: String, rightTable: String) = { - TestSQLContext.settings.synchronized { - TestSQLContext.set("spark.sql.join.broadcastTables", - s"${if (buildSide == BuildRight) rightTable else leftTable}") - val rdd = sql( s"""SELECT * FROM $leftTable JOIN $rightTable ON key = a""") - // Using `sparkPlan` because for relevant patterns in HashJoin to be - // matched, other strategies need to be applied. - val physical = rdd.queryExecution.sparkPlan - val bhj = physical.collect { case j: BroadcastHashJoin if j.buildSide == buildSide => j} - - assert(bhj.size === 1, "planner does not pick up hint to generate broadcast hash join") - checkAnswer( - rdd, - Seq( - (1, "1", 1, 1), - (1, "1", 1, 2), - (2, "2", 2, 1), - (2, "2", 2, 2), - (3, "3", 3, 1), - (3, "3", 3, 2) - )) - } - } - - mkTest(BuildRight, "testData", "testData2") - mkTest(BuildLeft, "testData", "testData2") - } - test("multiple-key equi-join is hash-join") { val x = testData2.as('x) val y = testData2.as('y) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 93792f698cfaf..08293f7f0ca30 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -28,50 +28,46 @@ class SQLConfSuite extends QueryTest { val testVal = "test.val.0" test("programmatic ways of basic setting and getting") { - TestSQLContext.settings.synchronized { - clear() - assert(getOption(testKey).isEmpty) - assert(getAll.toSet === Set()) + clear() + assert(getOption(testKey).isEmpty) + assert(getAll.toSet === Set()) - set(testKey, testVal) - assert(get(testKey) == testVal) - assert(get(testKey, testVal + "_") == testVal) - assert(getOption(testKey) == Some(testVal)) - assert(contains(testKey)) + set(testKey, testVal) + assert(get(testKey) == testVal) + assert(get(testKey, testVal + "_") == testVal) + assert(getOption(testKey) == Some(testVal)) + assert(contains(testKey)) - // Tests SQLConf as accessed from a SQLContext is mutable after - // the latter is initialized, unlike SparkConf inside a SparkContext. - assert(TestSQLContext.get(testKey) == testVal) - assert(TestSQLContext.get(testKey, testVal + "_") == testVal) - assert(TestSQLContext.getOption(testKey) == Some(testVal)) - assert(TestSQLContext.contains(testKey)) + // Tests SQLConf as accessed from a SQLContext is mutable after + // the latter is initialized, unlike SparkConf inside a SparkContext. + assert(TestSQLContext.get(testKey) == testVal) + assert(TestSQLContext.get(testKey, testVal + "_") == testVal) + assert(TestSQLContext.getOption(testKey) == Some(testVal)) + assert(TestSQLContext.contains(testKey)) - clear() - } + clear() } test("parse SQL set commands") { - TestSQLContext.settings.synchronized { - clear() - sql(s"set $testKey=$testVal") - assert(get(testKey, testVal + "_") == testVal) - assert(TestSQLContext.get(testKey, testVal + "_") == testVal) + clear() + sql(s"set $testKey=$testVal") + assert(get(testKey, testVal + "_") == testVal) + assert(TestSQLContext.get(testKey, testVal + "_") == testVal) - sql("set mapred.reduce.tasks=20") - assert(get("mapred.reduce.tasks", "0") == "20") - sql("set mapred.reduce.tasks = 40") - assert(get("mapred.reduce.tasks", "0") == "40") + sql("set mapred.reduce.tasks=20") + assert(get("mapred.reduce.tasks", "0") == "20") + sql("set mapred.reduce.tasks = 40") + assert(get("mapred.reduce.tasks", "0") == "40") - val key = "spark.sql.key" - val vs = "val0,val_1,val2.3,my_table" - sql(s"set $key=$vs") - assert(get(key, "0") == vs) + val key = "spark.sql.key" + val vs = "val0,val_1,val2.3,my_table" + sql(s"set $key=$vs") + assert(get(key, "0") == vs) - sql(s"set $key=") - assert(get(key, "0") == "") + sql(s"set $key=") + assert(get(key, "0") == "") - clear() - } + clear() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index fa1f32f8a49a9..6736189c96d4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -36,6 +36,21 @@ class SQLQuerySuite extends QueryTest { "test") } + test("SPARK-2407 Added Parser of SQL SUBSTR()") { + checkAnswer( + sql("SELECT substr(tableName, 1, 2) FROM tableName"), + "te") + checkAnswer( + sql("SELECT substr(tableName, 3) FROM tableName"), + "st") + checkAnswer( + sql("SELECT substring(tableName, 1, 2) FROM tableName"), + "te") + checkAnswer( + sql("SELECT substring(tableName, 3) FROM tableName"), + "st") + } + test("index into array") { checkAnswer( sql("SELECT data, data[0], data[0] + data[1], data[0 + 1] FROM arrayData"), @@ -397,40 +412,38 @@ class SQLQuerySuite extends QueryTest { } test("SET commands semantics using sql()") { - TestSQLContext.settings.synchronized { - clear() - val testKey = "test.key.0" - val testVal = "test.val.0" - val nonexistentKey = "nonexistent" - - // "set" itself returns all config variables currently specified in SQLConf. - assert(sql("SET").collect().size == 0) - - // "set key=val" - sql(s"SET $testKey=$testVal") - checkAnswer( - sql("SET"), - Seq(Seq(testKey, testVal)) - ) - - sql(s"SET ${testKey + testKey}=${testVal + testVal}") - checkAnswer( - sql("set"), - Seq( - Seq(testKey, testVal), - Seq(testKey + testKey, testVal + testVal)) - ) - - // "set key" - checkAnswer( - sql(s"SET $testKey"), - Seq(Seq(testKey, testVal)) - ) - checkAnswer( - sql(s"SET $nonexistentKey"), - Seq(Seq(nonexistentKey, "")) - ) - clear() - } + clear() + val testKey = "test.key.0" + val testVal = "test.val.0" + val nonexistentKey = "nonexistent" + + // "set" itself returns all config variables currently specified in SQLConf. + assert(sql("SET").collect().size == 0) + + // "set key=val" + sql(s"SET $testKey=$testVal") + checkAnswer( + sql("SET"), + Seq(Seq(testKey, testVal)) + ) + + sql(s"SET ${testKey + testKey}=${testVal + testVal}") + checkAnswer( + sql("set"), + Seq( + Seq(testKey, testVal), + Seq(testKey + testKey, testVal + testVal)) + ) + + // "set key" + checkAnswer( + sql(s"SET $testKey"), + Seq(Seq(testKey, testVal)) + ) + checkAnswer( + sql(s"SET $nonexistentKey"), + Seq(Seq(nonexistentKey, "")) + ) + clear() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index dbf315947ff47..3c911e9a4e7b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -65,7 +65,8 @@ case class AllDataTypes( doubleField: Double, shortField: Short, byteField: Byte, - booleanField: Boolean) + booleanField: Boolean, + binaryField: Array[Byte]) case class AllDataTypesWithNonPrimitiveType( stringField: String, @@ -76,9 +77,10 @@ case class AllDataTypesWithNonPrimitiveType( shortField: Short, byteField: Byte, booleanField: Boolean, + binaryField: Array[Byte], array: Seq[Int], map: Map[Int, String], - nested: Nested) + data: Data) class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { TestData // Load test data tables. @@ -116,7 +118,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val tempDir = getTempFilePath("parquetTest").getCanonicalPath val range = (0 to 255) TestSQLContext.sparkContext.parallelize(range) - .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0)) + .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0, + (0 to x).map(_.toByte).toArray)) .saveAsParquetFile(tempDir) val result = parquetFile(tempDir).collect() range.foreach { @@ -129,6 +132,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(result(i).getShort(5) === i.toShort) assert(result(i).getByte(6) === i.toByte) assert(result(i).getBoolean(7) === (i % 2 == 0)) + assert(result(i)(8) === (0 to i).map(_.toByte).toArray) } } @@ -138,7 +142,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA TestSQLContext.sparkContext.parallelize(range) .map(x => AllDataTypesWithNonPrimitiveType( s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0, - Seq(x), Map(x -> s"$x"), Nested(x, s"$x"))) + (0 to x).map(_.toByte).toArray, + (0 until x), (0 until x).map(i => i -> s"$i").toMap, Data((0 until x), Nested(x, s"$x")))) .saveAsParquetFile(tempDir) val result = parquetFile(tempDir).collect() range.foreach { @@ -151,9 +156,10 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(result(i).getShort(5) === i.toShort) assert(result(i).getByte(6) === i.toByte) assert(result(i).getBoolean(7) === (i % 2 == 0)) - assert(result(i)(8) === Seq(i)) - assert(result(i)(9) === Map(i -> s"$i")) - assert(result(i)(10) === new GenericRow(Array[Any](i, s"$i"))) + assert(result(i)(8) === (0 to i).map(_.toByte).toArray) + assert(result(i)(9) === (0 until i)) + assert(result(i)(10) === (0 until i).map(i => i -> s"$i").toMap) + assert(result(i)(11) === new GenericRow(Array[Any]((0 until i), new GenericRow(Array[Any](i, s"$i"))))) } } diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 5ede76e5c3904..f30ae28b81e06 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -31,6 +31,9 @@ jar Spark Project Hive http://spark.apache.org/ + + hive + @@ -48,6 +51,11 @@ hive-metastore ${hive.version} + + commons-httpclient + commons-httpclient + 3.1 + org.spark-project.hive hive-exec diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f83068860701f..8db60d32767b5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -43,14 +43,15 @@ import scala.collection.JavaConversions._ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Logging { import HiveMetastoreTypes._ - val client = Hive.get(hive.hiveconf) + /** Connection to hive metastore. Usages should lock on `this`. */ + protected[hive] val client = Hive.get(hive.hiveconf) val caseSensitive: Boolean = false def lookupRelation( db: Option[String], tableName: String, - alias: Option[String]): LogicalPlan = { + alias: Option[String]): LogicalPlan = synchronized { val (dbName, tblName) = processDatabaseAndTableName(db, tableName) val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase) val table = client.getTable(databaseName, tblName) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index b70104dd5be5a..300e249f5b2e1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -860,6 +860,7 @@ private[hive] object HiveQl { val BETWEEN = "(?i)BETWEEN".r val WHEN = "(?i)WHEN".r val CASE = "(?i)CASE".r + val SUBSTR = "(?i)SUBSTR(?:ING)?".r protected def nodeToExpr(node: Node): Expression = node match { /* Attribute References */ @@ -870,10 +871,7 @@ private[hive] object HiveQl { nodeToExpr(qualifier) match { case UnresolvedAttribute(qualifierName) => UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr)) - // The precidence for . seems to be wrong, so [] binds tighter an we need to go inside to - // find the underlying attribute references. - case GetItem(UnresolvedAttribute(qualifierName), ordinal) => - GetItem(UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr)), ordinal) + case other => GetField(other, attr) } /* Stars (*) */ @@ -987,6 +985,10 @@ private[hive] object HiveQl { /* Other functions */ case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand + case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: Nil) => + Substring(nodeToExpr(string), nodeToExpr(pos), Literal(Integer.MAX_VALUE, IntegerType)) + case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) => + Substring(nodeToExpr(string), nodeToExpr(pos), nodeToExpr(length)) /* UDFs - Must be last otherwise will preempt built in functions */ case Token("TOK_FUNCTION", Token(name, Nil) :: args) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 8cfde46186ca4..c3942578d6b5a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -164,13 +164,17 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon hivePartitionRDD.mapPartitions { iter => val hconf = broadcastedHiveConf.value.value val rowWithPartArr = new Array[Object](2) + + // The update and deserializer initialization are intentionally + // kept out of the below iter.map loop to save performance. + rowWithPartArr.update(1, partValues) + val deserializer = localDeserializer.newInstance() + deserializer.initialize(hconf, partProps) + // Map each tuple to a row object iter.map { value => - val deserializer = localDeserializer.newInstance() - deserializer.initialize(hconf, partProps) val deserializedRow = deserializer.deserialize(value) rowWithPartArr.update(0, deserializedRow) - rowWithPartArr.update(1, partValues) rowWithPartArr.asInstanceOf[Object] } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index ef8bae74530ec..e7016fa16eea9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -96,19 +96,9 @@ case class HiveTableScan( .getOrElse(sys.error(s"Can't find attribute $a")) val fieldObjectInspector = ref.getFieldObjectInspector - val unwrapHiveData = fieldObjectInspector match { - case _: HiveVarcharObjectInspector => - (value: Any) => value.asInstanceOf[HiveVarchar].getValue - case _: HiveDecimalObjectInspector => - (value: Any) => BigDecimal(value.asInstanceOf[HiveDecimal].bigDecimalValue()) - case _ => - identity[Any] _ - } - (row: Any, _: Array[String]) => { val data = objectInspector.getStructFieldData(row, ref) - val hiveData = unwrapData(data, fieldObjectInspector) - if (hiveData != null) unwrapHiveData(hiveData) else null + unwrapData(data, fieldObjectInspector) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 9b105308ab7cf..fc33c5b460d70 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -280,6 +280,10 @@ private[hive] case class HiveGenericUdf(name: String, children: Seq[Expression]) private[hive] trait HiveInspectors { def unwrapData(data: Any, oi: ObjectInspector): Any = oi match { + case hvoi: HiveVarcharObjectInspector => + if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue + case hdoi: HiveDecimalObjectInspector => + if (data == null) null else BigDecimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data) case li: ListObjectInspector => Option(li.getList(data)) diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-0-b1e2ade89ae898650f0be4f796d8947b b/sql/hive/src/test/resources/golden/correlationoptimizer1-0-b1e2ade89ae898650f0be4f796d8947b new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-0-b1e2ade89ae898650f0be4f796d8947b @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-1-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer1-1-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-1-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-10-5d712a42dcc1c4f7c797dabda5eb7b3a b/sql/hive/src/test/resources/golden/correlationoptimizer1-10-5d712a42dcc1c4f7c797dabda5eb7b3a new file mode 100644 index 0000000000000..58010b0040b74 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-10-5d712a42dcc1c4f7c797dabda5eb7b3a @@ -0,0 +1 @@ +3556 37 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-11-b1e2ade89ae898650f0be4f796d8947b b/sql/hive/src/test/resources/golden/correlationoptimizer1-11-b1e2ade89ae898650f0be4f796d8947b new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-11-b1e2ade89ae898650f0be4f796d8947b @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-12-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer1-12-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-12-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-13-f9f839aedb3a350719c0cbc53a06ace5 b/sql/hive/src/test/resources/golden/correlationoptimizer1-13-f9f839aedb3a350719c0cbc53a06ace5 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-14-dae4256e08d595317f8e09a56354a3d9 b/sql/hive/src/test/resources/golden/correlationoptimizer1-14-dae4256e08d595317f8e09a56354a3d9 new file mode 100644 index 0000000000000..235736a2807b6 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-14-dae4256e08d595317f8e09a56354a3d9 @@ -0,0 +1 @@ +3556 15 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-15-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer1-15-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-15-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-16-43f356b36962f2bade5706d8cf5ae6b4 b/sql/hive/src/test/resources/golden/correlationoptimizer1-16-43f356b36962f2bade5706d8cf5ae6b4 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-17-dae4256e08d595317f8e09a56354a3d9 b/sql/hive/src/test/resources/golden/correlationoptimizer1-17-dae4256e08d595317f8e09a56354a3d9 new file mode 100644 index 0000000000000..235736a2807b6 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-17-dae4256e08d595317f8e09a56354a3d9 @@ -0,0 +1 @@ +3556 15 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-18-b1e2ade89ae898650f0be4f796d8947b b/sql/hive/src/test/resources/golden/correlationoptimizer1-18-b1e2ade89ae898650f0be4f796d8947b new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-18-b1e2ade89ae898650f0be4f796d8947b @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-19-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer1-19-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-19-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-2-42a0eedc3751f792ad5438b2c64d3897 b/sql/hive/src/test/resources/golden/correlationoptimizer1-2-42a0eedc3751f792ad5438b2c64d3897 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-20-b4c1e27a7c1f61a3e9ae07c80e6e2973 b/sql/hive/src/test/resources/golden/correlationoptimizer1-20-b4c1e27a7c1f61a3e9ae07c80e6e2973 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-21-16c57348be42ca3cc2f80f7f92265696 b/sql/hive/src/test/resources/golden/correlationoptimizer1-21-16c57348be42ca3cc2f80f7f92265696 new file mode 100644 index 0000000000000..76c4941de407d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-21-16c57348be42ca3cc2f80f7f92265696 @@ -0,0 +1 @@ +3556 47 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-22-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer1-22-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-22-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-23-2cdc77fd60449f3547cf95d8eb09a2 b/sql/hive/src/test/resources/golden/correlationoptimizer1-23-2cdc77fd60449f3547cf95d8eb09a2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-24-16c57348be42ca3cc2f80f7f92265696 b/sql/hive/src/test/resources/golden/correlationoptimizer1-24-16c57348be42ca3cc2f80f7f92265696 new file mode 100644 index 0000000000000..76c4941de407d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-24-16c57348be42ca3cc2f80f7f92265696 @@ -0,0 +1 @@ +3556 47 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-25-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer1-25-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-25-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-26-8bcdcc5f01508f576d7bd6422c939225 b/sql/hive/src/test/resources/golden/correlationoptimizer1-26-8bcdcc5f01508f576d7bd6422c939225 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-27-d31433f229e853e8b8440b4ddc63c80e b/sql/hive/src/test/resources/golden/correlationoptimizer1-27-d31433f229e853e8b8440b4ddc63c80e new file mode 100644 index 0000000000000..76c4941de407d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-27-d31433f229e853e8b8440b4ddc63c80e @@ -0,0 +1 @@ +3556 47 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-28-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer1-28-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-28-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-29-941ecfef9448ecff56cc16bcfb233ee4 b/sql/hive/src/test/resources/golden/correlationoptimizer1-29-941ecfef9448ecff56cc16bcfb233ee4 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-3-5d712a42dcc1c4f7c797dabda5eb7b3a b/sql/hive/src/test/resources/golden/correlationoptimizer1-3-5d712a42dcc1c4f7c797dabda5eb7b3a new file mode 100644 index 0000000000000..58010b0040b74 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-3-5d712a42dcc1c4f7c797dabda5eb7b3a @@ -0,0 +1 @@ +3556 37 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-30-d31433f229e853e8b8440b4ddc63c80e b/sql/hive/src/test/resources/golden/correlationoptimizer1-30-d31433f229e853e8b8440b4ddc63c80e new file mode 100644 index 0000000000000..76c4941de407d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-30-d31433f229e853e8b8440b4ddc63c80e @@ -0,0 +1 @@ +3556 47 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-31-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer1-31-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-31-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-32-ef6502d6b282c8a6d228bba395b24724 b/sql/hive/src/test/resources/golden/correlationoptimizer1-32-ef6502d6b282c8a6d228bba395b24724 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-33-ea87e76dba02a46cb958148333e397b7 b/sql/hive/src/test/resources/golden/correlationoptimizer1-33-ea87e76dba02a46cb958148333e397b7 new file mode 100644 index 0000000000000..5aa2d482094af --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-33-ea87e76dba02a46cb958148333e397b7 @@ -0,0 +1 @@ +79136 500 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-34-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer1-34-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-34-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-35-b79b220859c09354e23b533c105ccbab b/sql/hive/src/test/resources/golden/correlationoptimizer1-35-b79b220859c09354e23b533c105ccbab new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-36-ea87e76dba02a46cb958148333e397b7 b/sql/hive/src/test/resources/golden/correlationoptimizer1-36-ea87e76dba02a46cb958148333e397b7 new file mode 100644 index 0000000000000..5aa2d482094af --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-36-ea87e76dba02a46cb958148333e397b7 @@ -0,0 +1 @@ +79136 500 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-37-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer1-37-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-37-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-38-638e5300f4c892c2bf27bd91a8f81b64 b/sql/hive/src/test/resources/golden/correlationoptimizer1-38-638e5300f4c892c2bf27bd91a8f81b64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-39-66010469a9cdb66851da9a727ef9fdad b/sql/hive/src/test/resources/golden/correlationoptimizer1-39-66010469a9cdb66851da9a727ef9fdad new file mode 100644 index 0000000000000..b4a3a9d327f47 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-39-66010469a9cdb66851da9a727ef9fdad @@ -0,0 +1 @@ +3556 500 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-4-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer1-4-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-4-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-40-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer1-40-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-40-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-41-3514c74c7f68f2d70cc6d51ac46c20 b/sql/hive/src/test/resources/golden/correlationoptimizer1-41-3514c74c7f68f2d70cc6d51ac46c20 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-42-66010469a9cdb66851da9a727ef9fdad b/sql/hive/src/test/resources/golden/correlationoptimizer1-42-66010469a9cdb66851da9a727ef9fdad new file mode 100644 index 0000000000000..b4a3a9d327f47 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-42-66010469a9cdb66851da9a727ef9fdad @@ -0,0 +1 @@ +3556 500 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-43-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer1-43-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-43-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-44-7490df6719cd7e47aa08dbcbc3266a92 b/sql/hive/src/test/resources/golden/correlationoptimizer1-44-7490df6719cd7e47aa08dbcbc3266a92 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-45-e71195e7d9f557e2abc7f03462d22dba b/sql/hive/src/test/resources/golden/correlationoptimizer1-45-e71195e7d9f557e2abc7f03462d22dba new file mode 100644 index 0000000000000..bb564e0fd06eb --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-45-e71195e7d9f557e2abc7f03462d22dba @@ -0,0 +1 @@ +3556 510 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-46-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer1-46-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-46-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-47-73da9fe2b0c2ee26c021ec3f2fa27272 b/sql/hive/src/test/resources/golden/correlationoptimizer1-47-73da9fe2b0c2ee26c021ec3f2fa27272 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-48-e71195e7d9f557e2abc7f03462d22dba b/sql/hive/src/test/resources/golden/correlationoptimizer1-48-e71195e7d9f557e2abc7f03462d22dba new file mode 100644 index 0000000000000..bb564e0fd06eb --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-48-e71195e7d9f557e2abc7f03462d22dba @@ -0,0 +1 @@ +3556 510 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-49-b1e2ade89ae898650f0be4f796d8947b b/sql/hive/src/test/resources/golden/correlationoptimizer1-49-b1e2ade89ae898650f0be4f796d8947b new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-49-b1e2ade89ae898650f0be4f796d8947b @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-5-a1c80c68b9a7597096c2580c3766f7f7 b/sql/hive/src/test/resources/golden/correlationoptimizer1-5-a1c80c68b9a7597096c2580c3766f7f7 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-50-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer1-50-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-50-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-51-fcf9bcb522f542637ccdea863b408448 b/sql/hive/src/test/resources/golden/correlationoptimizer1-51-fcf9bcb522f542637ccdea863b408448 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-52-3070366869308907e54797927805603 b/sql/hive/src/test/resources/golden/correlationoptimizer1-52-3070366869308907e54797927805603 new file mode 100644 index 0000000000000..edd216f16b190 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-52-3070366869308907e54797927805603 @@ -0,0 +1 @@ +3556 661329102 37 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-53-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer1-53-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-53-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-54-dad56e1f06c808b29e5dc8fb0c49efb2 b/sql/hive/src/test/resources/golden/correlationoptimizer1-54-dad56e1f06c808b29e5dc8fb0c49efb2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-55-3070366869308907e54797927805603 b/sql/hive/src/test/resources/golden/correlationoptimizer1-55-3070366869308907e54797927805603 new file mode 100644 index 0000000000000..edd216f16b190 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-55-3070366869308907e54797927805603 @@ -0,0 +1 @@ +3556 661329102 37 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-56-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer1-56-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-56-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-57-3cd3fbbbd8ee5c274fe3d6a45126cef4 b/sql/hive/src/test/resources/golden/correlationoptimizer1-57-3cd3fbbbd8ee5c274fe3d6a45126cef4 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-58-a6bba6d9b422adb386b35c62cecb548 b/sql/hive/src/test/resources/golden/correlationoptimizer1-58-a6bba6d9b422adb386b35c62cecb548 new file mode 100644 index 0000000000000..9fde6099b4f08 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-58-a6bba6d9b422adb386b35c62cecb548 @@ -0,0 +1 @@ +2835 29 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-59-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer1-59-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-59-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-6-5d712a42dcc1c4f7c797dabda5eb7b3a b/sql/hive/src/test/resources/golden/correlationoptimizer1-6-5d712a42dcc1c4f7c797dabda5eb7b3a new file mode 100644 index 0000000000000..58010b0040b74 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-6-5d712a42dcc1c4f7c797dabda5eb7b3a @@ -0,0 +1 @@ +3556 37 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-60-d6bbaf0d40010159095e4cac025c50c5 b/sql/hive/src/test/resources/golden/correlationoptimizer1-60-d6bbaf0d40010159095e4cac025c50c5 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-61-a6bba6d9b422adb386b35c62cecb548 b/sql/hive/src/test/resources/golden/correlationoptimizer1-61-a6bba6d9b422adb386b35c62cecb548 new file mode 100644 index 0000000000000..9fde6099b4f08 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-61-a6bba6d9b422adb386b35c62cecb548 @@ -0,0 +1 @@ +2835 29 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-7-24ca942f094b14b92086305cc125e833 b/sql/hive/src/test/resources/golden/correlationoptimizer1-7-24ca942f094b14b92086305cc125e833 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-7-24ca942f094b14b92086305cc125e833 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-8-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer1-8-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer1-8-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer1-9-d5bea91b4edb8be0428a336ff9c21dde b/sql/hive/src/test/resources/golden/correlationoptimizer1-9-d5bea91b4edb8be0428a336ff9c21dde new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-0-b1e2ade89ae898650f0be4f796d8947b b/sql/hive/src/test/resources/golden/correlationoptimizer10-0-b1e2ade89ae898650f0be4f796d8947b new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-0-b1e2ade89ae898650f0be4f796d8947b @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-1-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer10-1-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-1-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-10-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer10-10-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-10-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-11-a1b7af95dfd01783c07aa23208d6160 b/sql/hive/src/test/resources/golden/correlationoptimizer10-11-a1b7af95dfd01783c07aa23208d6160 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-12-1322cff0bdf29aab32e638ad48c71ff9 b/sql/hive/src/test/resources/golden/correlationoptimizer10-12-1322cff0bdf29aab32e638ad48c71ff9 new file mode 100644 index 0000000000000..9ea431d9f5d18 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-12-1322cff0bdf29aab32e638ad48c71ff9 @@ -0,0 +1,5 @@ +66 val_66 +98 val_98 +128 +146 val_146 +150 val_150 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-13-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer10-13-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-13-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-14-934b668c11600dc9c013c2ddc4c0d68c b/sql/hive/src/test/resources/golden/correlationoptimizer10-14-934b668c11600dc9c013c2ddc4c0d68c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-15-430ff20a144fb3dbf526232d9cb2baa b/sql/hive/src/test/resources/golden/correlationoptimizer10-15-430ff20a144fb3dbf526232d9cb2baa new file mode 100644 index 0000000000000..5abecc5df25fd --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-15-430ff20a144fb3dbf526232d9cb2baa @@ -0,0 +1,23 @@ +181 val_181 +183 val_183 +186 val_186 +187 val_187 +187 val_187 +187 val_187 +189 val_189 +190 val_190 +191 val_191 +191 val_191 +192 val_192 +193 val_193 +193 val_193 +193 val_193 +194 val_194 +195 val_195 +195 val_195 +196 val_196 +197 val_197 +197 val_197 +199 val_199 +199 val_199 +199 val_199 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-16-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer10-16-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-16-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-17-9d5521ecef1353d23fd2d4f7d78e7006 b/sql/hive/src/test/resources/golden/correlationoptimizer10-17-9d5521ecef1353d23fd2d4f7d78e7006 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-18-430ff20a144fb3dbf526232d9cb2baa b/sql/hive/src/test/resources/golden/correlationoptimizer10-18-430ff20a144fb3dbf526232d9cb2baa new file mode 100644 index 0000000000000..5abecc5df25fd --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-18-430ff20a144fb3dbf526232d9cb2baa @@ -0,0 +1,23 @@ +181 val_181 +183 val_183 +186 val_186 +187 val_187 +187 val_187 +187 val_187 +189 val_189 +190 val_190 +191 val_191 +191 val_191 +192 val_192 +193 val_193 +193 val_193 +193 val_193 +194 val_194 +195 val_195 +195 val_195 +196 val_196 +197 val_197 +197 val_197 +199 val_199 +199 val_199 +199 val_199 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-2-f9de06a4184ab1f42793327c1497437 b/sql/hive/src/test/resources/golden/correlationoptimizer10-2-f9de06a4184ab1f42793327c1497437 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-3-6a01aa7ca94cda4268af894b4fd852ea b/sql/hive/src/test/resources/golden/correlationoptimizer10-3-6a01aa7ca94cda4268af894b4fd852ea new file mode 100644 index 0000000000000..d00aeb4be0340 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-3-6a01aa7ca94cda4268af894b4fd852ea @@ -0,0 +1,15 @@ +66 1 +98 1 +128 1 +146 1 +150 1 +213 1 +224 1 +238 1 +255 1 +273 1 +278 1 +311 1 +369 1 +401 1 +406 1 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-4-777edd9d575f3480ca6cebe4be57b1f6 b/sql/hive/src/test/resources/golden/correlationoptimizer10-4-777edd9d575f3480ca6cebe4be57b1f6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-4-777edd9d575f3480ca6cebe4be57b1f6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-5-6f191930802d659058465b2e6de08dd3 b/sql/hive/src/test/resources/golden/correlationoptimizer10-5-6f191930802d659058465b2e6de08dd3 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-6-6a01aa7ca94cda4268af894b4fd852ea b/sql/hive/src/test/resources/golden/correlationoptimizer10-6-6a01aa7ca94cda4268af894b4fd852ea new file mode 100644 index 0000000000000..d00aeb4be0340 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-6-6a01aa7ca94cda4268af894b4fd852ea @@ -0,0 +1,15 @@ +66 1 +98 1 +128 1 +146 1 +150 1 +213 1 +224 1 +238 1 +255 1 +273 1 +278 1 +311 1 +369 1 +401 1 +406 1 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-7-b9d963d24994c47c3776dda6f7d3881f b/sql/hive/src/test/resources/golden/correlationoptimizer10-7-b9d963d24994c47c3776dda6f7d3881f new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-7-b9d963d24994c47c3776dda6f7d3881f @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-8-b2c8d0056b9f1b41cdaeaab4a1a5c9ec b/sql/hive/src/test/resources/golden/correlationoptimizer10-8-b2c8d0056b9f1b41cdaeaab4a1a5c9ec new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/correlationoptimizer10-9-1322cff0bdf29aab32e638ad48c71ff9 b/sql/hive/src/test/resources/golden/correlationoptimizer10-9-1322cff0bdf29aab32e638ad48c71ff9 new file mode 100644 index 0000000000000..9ea431d9f5d18 --- /dev/null +++ b/sql/hive/src/test/resources/golden/correlationoptimizer10-9-1322cff0bdf29aab32e638ad48c71ff9 @@ -0,0 +1,5 @@ +66 val_66 +98 val_98 +128 +146 val_146 +150 val_150 diff --git a/sql/hive/src/test/resources/golden/groupby_ppd-0-4b116ec2d8fb55c52b7b0d248c616ae2 b/sql/hive/src/test/resources/golden/groupby_ppd-0-4b116ec2d8fb55c52b7b0d248c616ae2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/groupby_ppd-1-db7b1db8f5e61f0fa78d2874e4d72d9d b/sql/hive/src/test/resources/golden/groupby_ppd-1-db7b1db8f5e61f0fa78d2874e4d72d9d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/groupby_ppd-2-d286410aa1d5f5c8d91b863a6d6e29c5 b/sql/hive/src/test/resources/golden/groupby_ppd-2-d286410aa1d5f5c8d91b863a6d6e29c5 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/limit_pushdown_negative-0-79b294d0081c3dfd36c5b8b5e78dc7fb b/sql/hive/src/test/resources/golden/limit_pushdown_negative-0-79b294d0081c3dfd36c5b8b5e78dc7fb new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/limit_pushdown_negative-0-79b294d0081c3dfd36c5b8b5e78dc7fb @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/limit_pushdown_negative-1-f87339637a48bd1533493ebbed5432a7 b/sql/hive/src/test/resources/golden/limit_pushdown_negative-1-f87339637a48bd1533493ebbed5432a7 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/limit_pushdown_negative-2-de7e5ac581b870fff10dc82c75c1c79e b/sql/hive/src/test/resources/golden/limit_pushdown_negative-2-de7e5ac581b870fff10dc82c75c1c79e new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/limit_pushdown_negative-3-be440c3f959ca53b758481aa90551984 b/sql/hive/src/test/resources/golden/limit_pushdown_negative-3-be440c3f959ca53b758481aa90551984 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/limit_pushdown_negative-4-4dedc8057d76af264c198beaacd7f000 b/sql/hive/src/test/resources/golden/limit_pushdown_negative-4-4dedc8057d76af264c198beaacd7f000 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/limit_pushdown_negative-5-543a20e69bd8987bc37a22c1c7ef33f1 b/sql/hive/src/test/resources/golden/limit_pushdown_negative-5-543a20e69bd8987bc37a22c1c7ef33f1 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/limit_pushdown_negative-6-3f8274466914ad200b33a2c83fa6dab5 b/sql/hive/src/test/resources/golden/limit_pushdown_negative-6-3f8274466914ad200b33a2c83fa6dab5 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/limit_pushdown_negative-7-fb7bf3783d4fb43673a202c4111d9092 b/sql/hive/src/test/resources/golden/limit_pushdown_negative-7-fb7bf3783d4fb43673a202c4111d9092 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_3-10-7b1ec929239ee305ea9da46ebb990c67 b/sql/hive/src/test/resources/golden/timestamp_3-10-7b1ec929239ee305ea9da46ebb990c67 new file mode 100644 index 0000000000000..1b0a140b5a384 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_3-10-7b1ec929239ee305ea9da46ebb990c67 @@ -0,0 +1 @@ +1.3041352164485E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_3-11-a63f40f6c4a022c16f8cf810e3b7ed2a b/sql/hive/src/test/resources/golden/timestamp_3-11-a63f40f6c4a022c16f8cf810e3b7ed2a new file mode 100644 index 0000000000000..d7ff6cd63d9f0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_3-11-a63f40f6c4a022c16f8cf810e3b7ed2a @@ -0,0 +1 @@ +2011-04-29 20:46:56.4485 diff --git a/sql/hive/src/test/resources/golden/timestamp_3-12-165256158e3db1ce19c3c9db3c8011d2 b/sql/hive/src/test/resources/golden/timestamp_3-12-165256158e3db1ce19c3c9db3c8011d2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_3-3-6143888a940bfcac1133330764f5a31a b/sql/hive/src/test/resources/golden/timestamp_3-3-6143888a940bfcac1133330764f5a31a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_3-4-935d0d2492beab99bbbba26ba62a1db4 b/sql/hive/src/test/resources/golden/timestamp_3-4-935d0d2492beab99bbbba26ba62a1db4 new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_3-4-935d0d2492beab99bbbba26ba62a1db4 @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/timestamp_3-5-8fe348d5d9b9903a26eda32d308b8e41 b/sql/hive/src/test/resources/golden/timestamp_3-5-8fe348d5d9b9903a26eda32d308b8e41 new file mode 100644 index 0000000000000..21e72e8ac3d7e --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_3-5-8fe348d5d9b9903a26eda32d308b8e41 @@ -0,0 +1 @@ +48 diff --git a/sql/hive/src/test/resources/golden/timestamp_3-6-6be5fe01c502cd24db32a3781c97a703 b/sql/hive/src/test/resources/golden/timestamp_3-6-6be5fe01c502cd24db32a3781c97a703 new file mode 100644 index 0000000000000..ee3be2941da6e --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_3-6-6be5fe01c502cd24db32a3781c97a703 @@ -0,0 +1 @@ +-31184 diff --git a/sql/hive/src/test/resources/golden/timestamp_3-7-6066ba0451cd0fcfac4bea6376e72add b/sql/hive/src/test/resources/golden/timestamp_3-7-6066ba0451cd0fcfac4bea6376e72add new file mode 100644 index 0000000000000..1cf1952ac0372 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_3-7-6066ba0451cd0fcfac4bea6376e72add @@ -0,0 +1 @@ +1304135216 diff --git a/sql/hive/src/test/resources/golden/timestamp_3-8-22e03daa775eab145d39ec0730953f7e b/sql/hive/src/test/resources/golden/timestamp_3-8-22e03daa775eab145d39ec0730953f7e new file mode 100644 index 0000000000000..1cf1952ac0372 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_3-8-22e03daa775eab145d39ec0730953f7e @@ -0,0 +1 @@ +1304135216 diff --git a/sql/hive/src/test/resources/golden/timestamp_3-9-ffc79abb874323e165963aa39f460a9b b/sql/hive/src/test/resources/golden/timestamp_3-9-ffc79abb874323e165963aa39f460a9b new file mode 100644 index 0000000000000..d21deca762237 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_3-9-ffc79abb874323e165963aa39f460a9b @@ -0,0 +1 @@ +1.30413517E9 diff --git a/sql/hive/src/test/resources/golden/timestamp_null-3-222c5ea127c747c71738b5dc5b80459c b/sql/hive/src/test/resources/golden/timestamp_null-3-222c5ea127c747c71738b5dc5b80459c new file mode 100644 index 0000000000000..7951defec192a --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_null-3-222c5ea127c747c71738b5dc5b80459c @@ -0,0 +1 @@ +NULL diff --git a/sql/hive/src/test/resources/golden/timestamp_null-4-ffc86f5c714eceabc36e92931b96beb0 b/sql/hive/src/test/resources/golden/timestamp_null-4-ffc86f5c714eceabc36e92931b96beb0 new file mode 100644 index 0000000000000..7951defec192a --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_null-4-ffc86f5c714eceabc36e92931b96beb0 @@ -0,0 +1 @@ +NULL diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index cdfc2d0c17384..63dbe57c4c772 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -84,6 +84,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_java_method", "create_merge_compressed", + // DFS commands + "symlink_text_input_format", + // Weird DDL differences result in failures on jenkins. "create_like2", "create_view_translate", @@ -278,7 +281,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "compute_stats_string", "compute_stats_table", "convert_enum_to_string", + "correlationoptimizer1", + "correlationoptimizer10", "correlationoptimizer11", + "correlationoptimizer14", "correlationoptimizer15", "correlationoptimizer2", "correlationoptimizer3", @@ -296,6 +302,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "ct_case_insensitive", "database_location", "database_properties", + "decimal_1", "decimal_4", "decimal_join", "default_partition_name", @@ -304,6 +311,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "describe_formatted_view_partitioned", "diff_part_input_formats", "disable_file_format_check", + "disallow_incompatible_type_change_off", "drop_function", "drop_index", "drop_multi_partitions", @@ -359,8 +367,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "groupby_map_ppr", "groupby_multi_insert_common_distinct", "groupby_multi_single_reducer2", + "groupby_multi_single_reducer3", "groupby_mutli_insert_common_distinct", "groupby_neg_float", + "groupby_ppd", "groupby_ppr", "groupby_sort_10", "groupby_sort_2", @@ -400,6 +410,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "input4", "input40", "input41", + "input49", "input4_cb_delim", "input6", "input7", @@ -491,6 +502,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "lateral_view_ppd", "leftsemijoin", "leftsemijoin_mr", + "limit_pushdown_negative", "lineage1", "literal_double", "literal_ints", @@ -598,6 +610,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "reduce_deduplicate", "reduce_deduplicate_exclude_gby", "reduce_deduplicate_exclude_join", + "reduce_deduplicate_extended", "reducesink_dedup", "rename_column", "router_join_ppr", @@ -646,7 +659,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "stats_publisher_error_1", "subq2", "tablename_with_select", + "timestamp_3", "timestamp_comparison", + "timestamp_null", + "timestamp_udf", "touch", "transform_ppr1", "transform_ppr2", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index a623d29b53973..d57e99db1858f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -390,7 +390,7 @@ class HiveQuerySuite extends HiveComparisonTest { hql("CREATE TABLE m(value MAP)") hql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") hql("SELECT * FROM m").collect().zip(hql("SELECT * FROM src LIMIT 10").collect()).map { - case (Row(map: Map[Int, String]), Row(key: Int, value: String)) => + case (Row(map: Map[_, _]), Row(key: Int, value: String)) => assert(map.size === 1) assert(map.head === (key, value)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index 67594b57d3dfa..fb03db12a0b01 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -case class Data(a: Int, B: Int, n: Nested) +case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) case class Nested(a: Int, B: Int) /** @@ -53,12 +53,18 @@ class HiveResolutionSuite extends HiveComparisonTest { test("case insensitivity with scala reflection") { // Test resolution with Scala Reflection - TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2)) :: Nil) + TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) .registerAsTable("caseSensitivityTest") hql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") } + test("nested repeated resolution") { + TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) + .registerAsTable("nestedRepeatedTest") + assert(hql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1) + } + /** * Negative examples. Currently only left here for documentation purposes. * TODO(marmbrus): Test that catalyst fails on these queries. diff --git a/streaming/pom.xml b/streaming/pom.xml index f506d6ce34a6f..f60697ce745b7 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -27,6 +27,9 @@ org.apache.spark spark-streaming_2.10 + + streaming + jar Spark Project Streaming http://spark.apache.org/ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index c4bdf01fa3744..c00e11d11910f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -737,7 +737,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } object JavaPairDStream { - implicit def fromPairDStream[K: ClassTag, V: ClassTag](dstream: DStream[(K, V)]) = { + implicit def fromPairDStream[K: ClassTag, V: ClassTag](dstream: DStream[(K, V)]) + : JavaPairDStream[K, V] = { new JavaPairDStream[K, V](dstream) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala index 6376cff78b78a..ed7da6dc1315e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala @@ -41,7 +41,7 @@ class QueueInputDStream[T: ClassTag]( if (oneAtATime && queue.size > 0) { buffer += queue.dequeue() } else { - buffer ++= queue + buffer ++= queue.dequeueAll(_ => true) } if (buffer.size > 0) { if (oneAtATime) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 78cc2daa56e53..0316b6862f195 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -44,7 +44,7 @@ private[streaming] class BlockGenerator( listener: BlockGeneratorListener, receiverId: Int, conf: SparkConf - ) extends Logging { + ) extends RateLimiter(conf) with Logging { private case class Block(id: StreamBlockId, buffer: ArrayBuffer[Any]) @@ -81,6 +81,7 @@ private[streaming] class BlockGenerator( * will be periodically pushed into BlockManager. */ def += (data: Any): Unit = synchronized { + waitToPush() currentBuffer += data } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala new file mode 100644 index 0000000000000..e4f6ba626ebbf --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.receiver + +import org.apache.spark.{Logging, SparkConf} +import java.util.concurrent.TimeUnit._ + +/** Provides waitToPush() method to limit the rate at which receivers consume data. + * + * waitToPush method will block the thread if too many messages have been pushed too quickly, + * and only return when a new message has been pushed. It assumes that only one message is + * pushed at a time. + * + * The spark configuration spark.streaming.receiver.maxRate gives the maximum number of messages + * per second that each receiver will accept. + * + * @param conf spark configuration + */ +private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging { + + private var lastSyncTime = System.nanoTime + private var messagesWrittenSinceSync = 0L + private val desiredRate = conf.getInt("spark.streaming.receiver.maxRate", 0) + private val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS) + + def waitToPush() { + if( desiredRate <= 0 ) { + return + } + val now = System.nanoTime + val elapsedNanosecs = math.max(now - lastSyncTime, 1) + val rate = messagesWrittenSinceSync.toDouble * 1000000000 / elapsedNanosecs + if (rate < desiredRate) { + // It's okay to write; just update some variables and return + messagesWrittenSinceSync += 1 + if (now > lastSyncTime + SYNC_INTERVAL) { + // Sync interval has passed; let's resync + lastSyncTime = now + messagesWrittenSinceSync = 1 + } + } else { + // Calculate how much time we should sleep to bring ourselves to the desired rate. + val targetTimeInMillis = messagesWrittenSinceSync * 1000 / desiredRate + val elapsedTimeInMillis = elapsedNanosecs / 1000000 + val sleepTimeInMillis = targetTimeInMillis - elapsedTimeInMillis + if (sleepTimeInMillis > 0) { + logTrace("Natural rate is " + rate + " per second but desired rate is " + + desiredRate + ", sleeping for " + sleepTimeInMillis + " ms to compensate.") + Thread.sleep(sleepTimeInMillis) + } + waitToPush() + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index cd0aa4d0dce70..952a74fd5f6de 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -29,7 +29,7 @@ import java.nio.charset.Charset import java.util.concurrent.{Executors, TimeUnit, ArrayBlockingQueue} import java.util.concurrent.atomic.AtomicInteger -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer, SynchronizedQueue} import com.google.common.io.Files import org.scalatest.BeforeAndAfter @@ -39,6 +39,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.util.ManualClock import org.apache.spark.util.Utils import org.apache.spark.streaming.receiver.{ActorHelper, Receiver} +import org.apache.spark.rdd.RDD class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { @@ -234,6 +235,95 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { logInfo("--------------------------------") assert(output.sum === numTotalRecords) } + + test("queue input stream - oneAtATime=true") { + // Set up the streaming context and input streams + val ssc = new StreamingContext(conf, batchDuration) + val queue = new SynchronizedQueue[RDD[String]]() + val queueStream = ssc.queueStream(queue, oneAtATime = true) + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] + val outputStream = new TestOutputStream(queueStream, outputBuffer) + def output = outputBuffer.filter(_.size > 0) + outputStream.register() + ssc.start() + + // Setup data queued into the stream + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val input = Seq("1", "2", "3", "4", "5") + val expectedOutput = input.map(Seq(_)) + //Thread.sleep(1000) + val inputIterator = input.toIterator + for (i <- 0 until input.size) { + // Enqueue more than 1 item per tick but they should dequeue one at a time + inputIterator.take(2).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) + clock.addToTime(batchDuration.milliseconds) + } + Thread.sleep(1000) + logInfo("Stopping context") + ssc.stop() + + // Verify whether data received was as expected + logInfo("--------------------------------") + logInfo("output.size = " + outputBuffer.size) + logInfo("output") + outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("expected output.size = " + expectedOutput.size) + logInfo("expected output") + expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("--------------------------------") + + // Verify whether all the elements received are as expected + assert(output.size === expectedOutput.size) + for (i <- 0 until output.size) { + assert(output(i) === expectedOutput(i)) + } + } + + test("queue input stream - oneAtATime=false") { + // Set up the streaming context and input streams + val ssc = new StreamingContext(conf, batchDuration) + val queue = new SynchronizedQueue[RDD[String]]() + val queueStream = ssc.queueStream(queue, oneAtATime = false) + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] + val outputStream = new TestOutputStream(queueStream, outputBuffer) + def output = outputBuffer.filter(_.size > 0) + outputStream.register() + ssc.start() + + // Setup data queued into the stream + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val input = Seq("1", "2", "3", "4", "5") + val expectedOutput = Seq(Seq("1", "2", "3"), Seq("4", "5")) + + // Enqueue the first 3 items (one by one), they should be merged in the next batch + val inputIterator = input.toIterator + inputIterator.take(3).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) + clock.addToTime(batchDuration.milliseconds) + Thread.sleep(1000) + + // Enqueue the remaining items (again one by one), merged in the final batch + inputIterator.foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i))) + clock.addToTime(batchDuration.milliseconds) + Thread.sleep(1000) + logInfo("Stopping context") + ssc.stop() + + // Verify whether data received was as expected + logInfo("--------------------------------") + logInfo("output.size = " + outputBuffer.size) + logInfo("output") + outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("expected output.size = " + expectedOutput.size) + logInfo("expected output") + expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("--------------------------------") + + // Verify whether all the elements received are as expected + assert(output.size === expectedOutput.size) + for (i <- 0 until output.size) { + assert(output(i) === expectedOutput(i)) + } + } } @@ -293,7 +383,10 @@ class TestActor(port: Int) extends Actor with ActorHelper { def bytesToString(byteString: ByteString) = byteString.utf8String - override def preStart = IOManager(context.system).connect(new InetSocketAddress(port)) + override def preStart(): Unit = { + @deprecated("suppress compile time deprecation warning", "1.0.0") + val unit = IOManager(context.system).connect(new InetSocketAddress(port)) + } def receive = { case IO.Read(socket, bytes) => diff --git a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala index d9ac3c91f6e36..f4e11f975de94 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala @@ -145,6 +145,44 @@ class NetworkReceiverSuite extends FunSuite with Timeouts { assert(recordedData.toSet === generatedData.toSet) } + test("block generator throttling") { + val blockGeneratorListener = new FakeBlockGeneratorListener + val blockInterval = 50 + val maxRate = 200 + val conf = new SparkConf().set("spark.streaming.blockInterval", blockInterval.toString). + set("spark.streaming.receiver.maxRate", maxRate.toString) + val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf) + val expectedBlocks = 20 + val waitTime = expectedBlocks * blockInterval + val expectedMessages = maxRate * waitTime / 1000 + val expectedMessagesPerBlock = maxRate * blockInterval / 1000 + val generatedData = new ArrayBuffer[Int] + + // Generate blocks + val startTime = System.currentTimeMillis() + blockGenerator.start() + var count = 0 + while(System.currentTimeMillis - startTime < waitTime) { + blockGenerator += count + generatedData += count + count += 1 + Thread.sleep(1) + } + blockGenerator.stop() + + val recordedData = blockGeneratorListener.arrayBuffers + assert(blockGeneratorListener.arrayBuffers.size > 0) + assert(recordedData.flatten.toSet === generatedData.toSet) + // recordedData size should be close to the expected rate + assert(recordedData.flatten.size >= expectedMessages * 0.9 && + recordedData.flatten.size <= expectedMessages * 1.1 ) + // the first and last block may be incomplete, so we slice them out + recordedData.slice(1, recordedData.size - 1).foreach { block => + assert(block.size >= expectedMessagesPerBlock * 0.8 && + block.size <= expectedMessagesPerBlock * 1.2 ) + } + } + /** * An implementation of NetworkReceiver that is used for testing a receiver's life cycle. */ diff --git a/tools/pom.xml b/tools/pom.xml index 79cd8551d0722..c0ee8faa7a615 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -26,6 +26,9 @@ org.apache.spark spark-tools_2.10 + + tools + jar Spark Project Tools http://spark.apache.org/ diff --git a/yarn/alpha/pom.xml b/yarn/alpha/pom.xml index b8a631dd0bb3b..5b13a1f002d6e 100644 --- a/yarn/alpha/pom.xml +++ b/yarn/alpha/pom.xml @@ -23,6 +23,9 @@ 1.1.0-SNAPSHOT ../pom.xml + + yarn-alpha + org.apache.spark spark-yarn-alpha_2.10 diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 438737f7a6b60..062f946a9fe93 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -184,6 +184,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, private def startUserClass(): Thread = { logInfo("Starting the user JAR in a separate Thread") + System.setProperty("spark.executor.instances", args.numExecutors.toString) val mainMethod = Class.forName( args.userClass, false /* initialize */ , diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index bfdb6232f5113..a86ad256dfa39 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -32,6 +32,7 @@ import akka.actor.Terminated import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.util.{Utils, AkkaUtils} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.AddWebUIFilter import org.apache.spark.scheduler.SplitInfo import org.apache.spark.deploy.SparkHadoopUtil @@ -81,6 +82,9 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp case x: DisassociatedEvent => logInfo(s"Driver terminated or disconnected! Shutting down. $x") driverClosed = true + case x: AddWebUIFilter => + logInfo(s"Add WebUI Filter. $x") + driver ! x } } @@ -111,7 +115,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp } waitForSparkMaster() - + addAmIpFilter() // Allocate all containers allocateExecutors() @@ -171,7 +175,8 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp } private def registerApplicationMaster(): RegisterApplicationMasterResponse = { - logInfo("Registering the ApplicationMaster") + val appUIAddress = sparkConf.get("spark.driver.appUIAddress", "") + logInfo(s"Registering the ApplicationMaster with appUIAddress: $appUIAddress") val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest]) .asInstanceOf[RegisterApplicationMasterRequest] appMasterRequest.setApplicationAttemptId(appAttemptId) @@ -180,10 +185,21 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp appMasterRequest.setHost(Utils.localHostName()) appMasterRequest.setRpcPort(0) // What do we provide here ? Might make sense to expose something sensible later ? - appMasterRequest.setTrackingUrl("") + appMasterRequest.setTrackingUrl(appUIAddress) resourceManager.registerApplicationMaster(appMasterRequest) } + // add the yarn amIpFilter that Yarn requires for properly securing the UI + private def addAmIpFilter() { + val proxy = YarnConfiguration.getProxyHostAndPort(conf) + val parts = proxy.split(":") + val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) + val uriBase = "http://" + proxy + proxyBase + val amFilter = "PROXY_HOST=" + parts(0) + "," + "PROXY_URI_BASE=" + uriBase + val amFilterName = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" + actor ! AddWebUIFilter(amFilterName, amFilter, proxyBase) + } + private def waitForSparkMaster() { logInfo("Waiting for spark driver to be reachable.") var driverUp = false diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index 25cc9016b10a6..4c383ab574abe 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -26,7 +26,7 @@ class ApplicationMasterArguments(val args: Array[String]) { var userArgs: Seq[String] = Seq[String]() var executorMemory = 1024 var executorCores = 1 - var numExecutors = 2 + var numExecutors = ApplicationMasterArguments.DEFAULT_NUMBER_EXECUTORS parseArgs(args.toList) @@ -93,3 +93,7 @@ class ApplicationMasterArguments(val args: Array[String]) { System.exit(exitCode) } } + +object ApplicationMasterArguments { + val DEFAULT_NUMBER_EXECUTORS = 2 +} diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala index 6b91e6b9eb899..15e8c21aa5906 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala @@ -40,8 +40,10 @@ private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configur override def postStartHook() { + super.postStartHook() // The yarn application is running, but the executor might not yet ready // Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt + // TODO It needn't after waitBackendReady Thread.sleep(2000L) logInfo("YarnClientClusterScheduler.postStartHook done") } diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index fd2694fe7278d..1b37c4bb13f49 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -48,6 +48,7 @@ private[spark] class YarnClientSchedulerBackend( val driverHost = conf.get("spark.driver.host") val driverPort = conf.get("spark.driver.port") val hostport = driverHost + ":" + driverPort + conf.set("spark.driver.appUIAddress", sc.ui.appUIHostPort) val argsArrayBuf = new ArrayBuffer[String]() argsArrayBuf += ( @@ -75,6 +76,7 @@ private[spark] class YarnClientSchedulerBackend( logDebug("ClientArguments called with: " + argsArrayBuf) val args = new ClientArguments(argsArrayBuf.toArray, conf) + totalExpectedExecutors.set(args.numExecutors) client = new Client(args, conf) appId = client.runApp() waitForApp() diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala index 39cdd2e8a522b..9ee53d797c8ea 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala @@ -48,9 +48,11 @@ private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) override def postStartHook() { val sparkContextInitialized = ApplicationMaster.sparkContextInitialized(sc) + super.postStartHook() if (sparkContextInitialized){ ApplicationMaster.waitForInitialAllocations() // Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt + // TODO It needn't after waitBackendReady Thread.sleep(3000L) } logInfo("YarnClusterScheduler.postStartHook done") diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala new file mode 100644 index 0000000000000..a04b08f43cc5a --- /dev/null +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.apache.spark.SparkContext +import org.apache.spark.deploy.yarn.ApplicationMasterArguments +import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.util.IntParam + +private[spark] class YarnClusterSchedulerBackend( + scheduler: TaskSchedulerImpl, + sc: SparkContext) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) { + + override def start() { + super.start() + var numExecutors = ApplicationMasterArguments.DEFAULT_NUMBER_EXECUTORS + if (System.getenv("SPARK_EXECUTOR_INSTANCES") != null) { + numExecutors = IntParam.unapply(System.getenv("SPARK_EXECUTOR_INSTANCES")).getOrElse(numExecutors) + } + // System property can override environment variable. + numExecutors = sc.getConf.getInt("spark.executor.instances", numExecutors) + totalExpectedExecutors.set(numExecutors) + } +} diff --git a/yarn/pom.xml b/yarn/pom.xml index ef7066ef1fdfc..efb473aa1b261 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -28,6 +28,9 @@ yarn-parent_2.10 pom Spark Project YARN Parent POM + + yarn + diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml index 0931beb505508..ceaf9f9d71001 100644 --- a/yarn/stable/pom.xml +++ b/yarn/stable/pom.xml @@ -23,6 +23,9 @@ 1.1.0-SNAPSHOT ../pom.xml + + yarn-stable + org.apache.spark spark-yarn_2.10 diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index ee1e9c9c23d22..1a24ec759b546 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -164,6 +164,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, private def startUserClass(): Thread = { logInfo("Starting the user JAR in a separate Thread") + System.setProperty("spark.executor.instances", args.numExecutors.toString) val mainMethod = Class.forName( args.userClass, false, diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index f71ad036ce0f2..5ac95f3798723 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -31,10 +31,12 @@ import akka.actor.Terminated import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.util.{Utils, AkkaUtils} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.AddWebUIFilter import org.apache.spark.scheduler.SplitInfo import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.hadoop.yarn.webapp.util.WebAppUtils /** * An application master that allocates executors on behalf of a driver that is running outside @@ -82,6 +84,9 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp case x: DisassociatedEvent => logInfo(s"Driver terminated or disconnected! Shutting down. $x") driverClosed = true + case x: AddWebUIFilter => + logInfo(s"Add WebUI Filter. $x") + driver ! x } } @@ -99,6 +104,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp registerApplicationMaster() waitForSparkMaster() + addAmIpFilter() // Allocate all containers allocateExecutors() @@ -142,9 +148,20 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp } private def registerApplicationMaster(): RegisterApplicationMasterResponse = { - logInfo("Registering the ApplicationMaster") - // TODO: Find out client's Spark UI address and fill in here? - amClient.registerApplicationMaster(Utils.localHostName(), 0, "") + val appUIAddress = sparkConf.get("spark.driver.appUIAddress", "") + logInfo(s"Registering the ApplicationMaster with appUIAddress: $appUIAddress") + amClient.registerApplicationMaster(Utils.localHostName(), 0, appUIAddress) + } + + // add the yarn amIpFilter that Yarn requires for properly securing the UI + private def addAmIpFilter() { + val proxy = WebAppUtils.getProxyHostAndPort(conf) + val parts = proxy.split(":") + val proxyBase = System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) + val uriBase = "http://" + proxy + proxyBase + val amFilter = "PROXY_HOST=" + parts(0) + "," + "PROXY_URI_BASE=" + uriBase + val amFilterName = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" + actor ! AddWebUIFilter(amFilterName, amFilter, proxyBase) } private def waitForSparkMaster() {
Property NameDefaultMeaning
spark.deploy.retainedApplications200 + The maximum number of completed applications to display. Older applications will be dropped from the UI to maintain this limit.
+
spark.deploy.retainedDrivers200 + The maximum number of completed drivers to display. Older drivers will be dropped from the UI to maintain this limit.
+
spark.deploy.spreadOut true