diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000..c6b4aa5344757 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,12 @@ +## Contributing to Spark + +Contributions via GitHub pull requests are gladly accepted from their original +author. Along with any pull requests, please state that the contribution is +your original work and that you license the work to the project under the +project's open source license. Whether or not you state this explicitly, by +submitting any copyrighted material via pull request, email, or other means +you agree to license the material under the project's open source license and +warrant that you have the legal authority to do so. + +Please see [Contributing to Spark wiki page](https://cwiki.apache.org/SPARK/Contributing+to+Spark) +for more information. diff --git a/README.md b/README.md index 5b09ad86849e7..8dd8b70696aa2 100644 --- a/README.md +++ b/README.md @@ -13,16 +13,19 @@ and Spark Streaming for stream processing. ## Online Documentation You can find the latest Spark documentation, including a programming -guide, on the project webpage at . +guide, on the [project web page](http://spark.apache.org/documentation.html). This README file only contains basic setup instructions. ## Building Spark -Spark is built on Scala 2.10. To build Spark and its example programs, run: +Spark is built using [Apache Maven](http://maven.apache.org/). +To build Spark and its example programs, run: - ./sbt/sbt assembly + mvn -DskipTests clean package (You do not need to do this if you downloaded a pre-built package.) +More detailed documentation is available from the project site, at +["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). ## Interactive Scala Shell @@ -71,73 +74,24 @@ can be run using: ./dev/run-tests +Please see the guidance on how to +[run all automated tests](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark#ContributingtoSpark-AutomatedTesting). + ## A Note About Hadoop Versions Spark uses the Hadoop core library to talk to HDFS and other Hadoop-supported storage systems. Because the protocols have changed in different versions of Hadoop, you must build Spark against the same version that your cluster runs. -You can change the version by setting `-Dhadoop.version` when building Spark. - -For Apache Hadoop versions 1.x, Cloudera CDH MRv1, and other Hadoop -versions without YARN, use: - - # Apache Hadoop 1.2.1 - $ sbt/sbt -Dhadoop.version=1.2.1 assembly - - # Cloudera CDH 4.2.0 with MapReduce v1 - $ sbt/sbt -Dhadoop.version=2.0.0-mr1-cdh4.2.0 assembly - -For Apache Hadoop 2.2.X, 2.1.X, 2.0.X, 0.23.x, Cloudera CDH MRv2, and other Hadoop versions -with YARN, also set `-Pyarn`: - - # Apache Hadoop 2.0.5-alpha - $ sbt/sbt -Dhadoop.version=2.0.5-alpha -Pyarn assembly - - # Cloudera CDH 4.2.0 with MapReduce v2 - $ sbt/sbt -Dhadoop.version=2.0.0-cdh4.2.0 -Pyarn assembly - - # Apache Hadoop 2.2.X and newer - $ sbt/sbt -Dhadoop.version=2.2.0 -Pyarn assembly - -When developing a Spark application, specify the Hadoop version by adding the -"hadoop-client" artifact to your project's dependencies. 
For example, if you're -using Hadoop 1.2.1 and build your application using SBT, add this entry to -`libraryDependencies`: - - "org.apache.hadoop" % "hadoop-client" % "1.2.1" -If your project is built with Maven, add this to your POM file's `` section: - - - org.apache.hadoop - hadoop-client - 1.2.1 - - - -## A Note About Thrift JDBC server and CLI for Spark SQL - -Spark SQL supports Thrift JDBC server and CLI. -See sql-programming-guide.md for more information about using the JDBC server and CLI. -You can use those features by setting `-Phive` when building Spark as follows. - - $ sbt/sbt -Phive assembly +Please refer to the build documentation at +["Specifying the Hadoop Version"](http://spark.apache.org/docs/latest/building-spark.html#specifying-the-hadoop-version) +for detailed guidance on building for a particular distribution of Hadoop, including +building for particular Hive and Hive Thriftserver distributions. See also +["Third Party Hadoop Distributions"](http://spark.apache.org/docs/latest/hadoop-third-party-distributions.html) +for guidance on building a Spark application that works with a particular +distribution. ## Configuration Please refer to the [Configuration guide](http://spark.apache.org/docs/latest/configuration.html) in the online documentation for an overview on how to configure Spark. - - -## Contributing to Spark - -Contributions via GitHub pull requests are gladly accepted from their original -author. Along with any pull requests, please state that the contribution is -your original work and that you license the work to the project under the -project's open source license. Whether or not you state this explicitly, by -submitting any copyrighted material via pull request, email, or other means -you agree to license the material under the project's open source license and -warrant that you have the legal authority to do so. - -Please see [Contributing to Spark wiki page](https://cwiki.apache.org/SPARK/Contributing+to+Spark) -for more information. diff --git a/assembly/pom.xml b/assembly/pom.xml index 4146168fc804b..604b1ab3de6a8 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -88,6 +88,20 @@ + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + org.apache.maven.plugins + maven-install-plugin + + true + + org.apache.maven.plugins diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala index 55241d33cd3f0..ccb262a4ee02a 100644 --- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala +++ b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala @@ -24,8 +24,6 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.storage.StorageLevel -import scala.language.postfixOps - class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable class TestMessage(val targetId: String) extends Message[String] with Serializable diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index 15c6779402994..0f63e36d8aeca 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -43,6 +43,7 @@ if [ -n "$SPARK_PREPEND_CLASSES" ]; then echo "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark"\ "classes ahead of assembly." 
>&2 CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/core/target/jars/*" CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/classes" diff --git a/bin/spark-class b/bin/spark-class index 5f5f9ea74888d..613dc9c4566f2 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -105,7 +105,7 @@ else exit 1 fi fi -JAVA_VERSION=$("$RUNNER" -version 2>&1 | sed 's/java version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q') +JAVA_VERSION=$("$RUNNER" -version 2>&1 | sed 's/.* version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q') # Set JAVA_OPTS to be able to load native libraries and to set heap size if [ "$JAVA_VERSION" -ge 18 ]; then diff --git a/core/pom.xml b/core/pom.xml index b2b788a4bc13b..2a81f6df289c0 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -351,6 +351,33 @@ + + + org.apache.maven.plugins + maven-dependency-plugin + + + copy-dependencies + package + + copy-dependencies + + + ${project.build.directory} + false + false + true + true + guava + true + + + + diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 12b15fe0815be..3832a780ec4bc 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -162,7 +162,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { // always add the current user and SPARK_USER to the viewAcls private val defaultAclUsers = Set[String](System.getProperty("user.name", ""), - Option(System.getenv("SPARK_USER")).getOrElse("")) + Option(System.getenv("SPARK_USER")).getOrElse("")).filter(!_.isEmpty) setViewAcls(defaultAclUsers, sparkConf.get("spark.ui.view.acls", "")) setModifyAcls(defaultAclUsers, sparkConf.get("spark.modify.acls", "")) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c6c5b8f22b549..428f019b02a23 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -220,8 +220,14 @@ class SparkContext(config: SparkConf) extends Logging { new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf) // Initialize the Spark UI, registering all associated listeners - private[spark] val ui = new SparkUI(this) - ui.bind() + private[spark] val ui: Option[SparkUI] = + if (conf.getBoolean("spark.ui.enabled", true)) { + Some(new SparkUI(this)) + } else { + // For tests, do not enable the UI + None + } + ui.foreach(_.bind()) /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf) @@ -990,7 +996,7 @@ class SparkContext(config: SparkConf) extends Logging { /** Shut down the SparkContext. */ def stop() { postApplicationEnd() - ui.stop() + ui.foreach(_.stop()) // Do this only if not stopped already - best case effort. // prevent NPE if stopped more than once. 
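The hunk above makes the web UI optional behind a new `spark.ui.enabled` flag (default true), so `SparkContext.ui` becomes an `Option[SparkUI]`. A minimal sketch of how a test might take advantage of it; the master and app name here are illustrative, not from the patch:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// With spark.ui.enabled set to false, no SparkUI is created or bound,
// which avoids port 4040 conflicts when many suites run in parallel.
val conf = new SparkConf()
  .setMaster("local")
  .setAppName("ui-disabled-example")
  .set("spark.ui.enabled", "false")

val sc = new SparkContext(conf)
try {
  // ... exercise the code under test ...
} finally {
  sc.stop()
}
```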
val dagSchedulerCopy = dagScheduler @@ -1066,11 +1072,8 @@ class SparkContext(config: SparkConf) extends Logging { val callSite = getCallSite val cleanedFunc = clean(func) logInfo("Starting job: " + callSite.shortForm) - val start = System.nanoTime dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal, resultHandler, localProperties.get) - logInfo( - "Job finished: " + callSite.shortForm + ", took " + (System.nanoTime - start) / 1e9 + " s") rdd.doCheckpoint() } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index dd95e406f2a8e..009ed64775844 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -108,6 +108,14 @@ class SparkEnv ( pythonWorkers.get(key).foreach(_.stopWorker(worker)) } } + + private[spark] + def releasePythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) { + synchronized { + val key = (pythonExec, envVars) + pythonWorkers.get(key).foreach(_.releaseWorker(worker)) + } + } } object SparkEnv extends Logging { diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 2b99b8a5af250..51b3e4d5e0936 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics -import org.apache.spark.util.TaskCompletionListener +import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener} /** @@ -41,7 +41,7 @@ class TaskContext( val attemptId: Long, val runningLocally: Boolean = false, private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty) - extends Serializable { + extends Serializable with Logging { @deprecated("use partitionId", "0.8.1") def splitId = partitionId @@ -103,8 +103,20 @@ class TaskContext( /** Marks the task as completed and triggers the listeners. */ private[spark] def markTaskCompleted(): Unit = { completed = true + val errorMsgs = new ArrayBuffer[String](2) // Process complete callbacks in the reverse order of registration - onCompleteCallbacks.reverse.foreach { _.onTaskCompletion(this) } + onCompleteCallbacks.reverse.foreach { listener => + try { + listener.onTaskCompletion(this) + } catch { + case e: Throwable => + errorMsgs += e.getMessage + logError("Error in TaskCompletionListener", e) + } + } + if (errorMsgs.nonEmpty) { + throw new TaskCompletionListenerException(errorMsgs) + } } /** Marks the task for interruption, i.e. cancellation. */ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 8e178bc8480f7..791d853a015a1 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -17,6 +17,7 @@ package org.apache.spark.api.java +import java.io.Closeable import java.util import java.util.{Map => JMap} @@ -40,7 +41,9 @@ import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD} * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns * [[org.apache.spark.api.java.JavaRDD]]s and works with Java collections instead of Scala ones. 
*/ -class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWorkaround { +class JavaSparkContext(val sc: SparkContext) + extends JavaSparkContextVarargsWorkaround with Closeable { + /** * Create a JavaSparkContext that loads settings from system properties (for instance, when * launching with ./bin/spark-submit). @@ -534,6 +537,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork sc.stop() } + override def close(): Unit = stop() + /** * Get Spark's home location from either a value set through the constructor, * or the spark.home Java property, or the SPARK_HOME environment variable diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index ae8010300a500..12b345a8fa7c3 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -23,6 +23,7 @@ import java.nio.charset.Charset import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} import scala.collection.JavaConversions._ +import scala.collection.mutable import scala.language.existentials import scala.reflect.ClassTag import scala.util.{Try, Success, Failure} @@ -52,6 +53,7 @@ private[spark] class PythonRDD( extends RDD[Array[Byte]](parent) { val bufferSize = conf.getInt("spark.buffer.size", 65536) + val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true) override def getPartitions = parent.partitions @@ -63,19 +65,26 @@ private[spark] class PythonRDD( val localdir = env.blockManager.diskBlockManager.localDirs.map( f => f.getPath()).mkString(",") envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread + if (reuse_worker) { + envVars += ("SPARK_REUSE_WORKER" -> "1") + } val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) // Start a thread to feed the process input from our parent's iterator val writerThread = new WriterThread(env, worker, split, context) + var complete_cleanly = false context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() - - // Cleanup the worker socket. This will also cause the Python worker to exit. 
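Because `JavaSparkContext` now extends `java.io.Closeable` (with `close()` delegating to `stop()`), it can be treated like any other managed resource, for example with try-with-resources from Java 7. A rough Scala sketch, not taken from the patch:

```scala
import org.apache.spark.api.java.JavaSparkContext

val jsc = new JavaSparkContext("local", "closeable-example")  // illustrative app name
try {
  // ... build and run JavaRDDs as usual ...
} finally {
  jsc.close()  // now equivalent to jsc.stop()
}
```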
- try { - worker.close() - } catch { - case e: Exception => logWarning("Failed to close worker socket", e) + if (reuse_worker && complete_cleanly) { + env.releasePythonWorker(pythonExec, envVars.toMap, worker) + } else { + try { + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } } } @@ -115,6 +124,10 @@ private[spark] class PythonRDD( val total = finishTime - startTime logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish)) + val memoryBytesSpilled = stream.readLong() + val diskBytesSpilled = stream.readLong() + context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled + context.taskMetrics.diskBytesSpilled += diskBytesSpilled read() case SpecialLengths.PYTHON_EXCEPTION_THROWN => // Signals that an exception has been thrown in python @@ -133,6 +146,7 @@ private[spark] class PythonRDD( stream.readFully(update) accumulator += Collections.singletonList(update) } + complete_cleanly = true null } } catch { @@ -195,11 +209,26 @@ private[spark] class PythonRDD( PythonRDD.writeUTF(include, dataOut) } // Broadcast variables - dataOut.writeInt(broadcastVars.length) + val oldBids = PythonRDD.getWorkerBroadcasts(worker) + val newBids = broadcastVars.map(_.id).toSet + // number of different broadcasts + val cnt = oldBids.diff(newBids).size + newBids.diff(oldBids).size + dataOut.writeInt(cnt) + for (bid <- oldBids) { + if (!newBids.contains(bid)) { + // remove the broadcast from worker + dataOut.writeLong(- bid - 1) // bid >= 0 + oldBids.remove(bid) + } + } for (broadcast <- broadcastVars) { - dataOut.writeLong(broadcast.id) - dataOut.writeInt(broadcast.value.length) - dataOut.write(broadcast.value) + if (!oldBids.contains(broadcast.id)) { + // send new broadcast + dataOut.writeLong(broadcast.id) + dataOut.writeInt(broadcast.value.length) + dataOut.write(broadcast.value) + oldBids.add(broadcast.id) + } } dataOut.flush() // Serialized command: @@ -207,17 +236,18 @@ private[spark] class PythonRDD( dataOut.write(command) // Data values PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut) + dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) dataOut.flush() } catch { case e: Exception if context.isCompleted || context.isInterrupted => logDebug("Exception thrown after task completion (likely due to cleanup)", e) + worker.shutdownOutput() case e: Exception => // We must avoid throwing exceptions here, because the thread uncaught exception handler // will kill the whole executor (see org.apache.spark.executor.Executor). _exception = e - } finally { - Try(worker.shutdownOutput()) // kill Python worker process + worker.shutdownOutput() } } } @@ -278,6 +308,14 @@ private object SpecialLengths { private[spark] object PythonRDD extends Logging { val UTF8 = Charset.forName("UTF-8") + // remember the broadcasts sent to each worker + private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() + private def getWorkerBroadcasts(worker: Socket) = { + synchronized { + workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]()) + } + } + /** * Adapter for calling SparkContext#runJob from Python. * @@ -738,7 +776,7 @@ private[spark] object PythonRDD extends Logging { } /** - * Convert and RDD of Java objects to and RDD of serialized Python objects, that is usable by + * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by * PySpark. 
*/ def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 4c4796f6c59ba..71bdf0fe1b917 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -40,7 +40,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String var daemon: Process = null val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) var daemonPort: Int = 0 - var daemonWorkers = new mutable.WeakHashMap[Socket, Int]() + val daemonWorkers = new mutable.WeakHashMap[Socket, Int]() + val idleWorkers = new mutable.Queue[Socket]() + var lastActivity = 0L + new MonitorThread().start() var simpleWorkers = new mutable.WeakHashMap[Socket, Process]() @@ -51,6 +54,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String def create(): Socket = { if (useDaemon) { + synchronized { + if (idleWorkers.size > 0) { + return idleWorkers.dequeue() + } + } createThroughDaemon() } else { createSimpleWorker() @@ -199,9 +207,44 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } } + /** + * Monitor all the idle workers, kill them after timeout. + */ + private class MonitorThread extends Thread(s"Idle Worker Monitor for $pythonExec") { + + setDaemon(true) + + override def run() { + while (true) { + synchronized { + if (lastActivity + IDLE_WORKER_TIMEOUT_MS < System.currentTimeMillis()) { + cleanupIdleWorkers() + lastActivity = System.currentTimeMillis() + } + } + Thread.sleep(10000) + } + } + } + + private def cleanupIdleWorkers() { + while (idleWorkers.length > 0) { + val worker = idleWorkers.dequeue() + try { + // the worker will exit after closing the socket + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } + } + private def stopDaemon() { synchronized { if (useDaemon) { + cleanupIdleWorkers() + // Request shutdown of existing daemon by sending SIGTERM if (daemon != null) { daemon.destroy() @@ -220,23 +263,43 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } def stopWorker(worker: Socket) { - if (useDaemon) { - if (daemon != null) { - daemonWorkers.get(worker).foreach { pid => - // tell daemon to kill worker by pid - val output = new DataOutputStream(daemon.getOutputStream) - output.writeInt(pid) - output.flush() - daemon.getOutputStream.flush() + synchronized { + if (useDaemon) { + if (daemon != null) { + daemonWorkers.get(worker).foreach { pid => + // tell daemon to kill worker by pid + val output = new DataOutputStream(daemon.getOutputStream) + output.writeInt(pid) + output.flush() + daemon.getOutputStream.flush() + } } + } else { + simpleWorkers.get(worker).foreach(_.destroy()) } - } else { - simpleWorkers.get(worker).foreach(_.destroy()) } worker.close() } + + def releaseWorker(worker: Socket) { + if (useDaemon) { + synchronized { + lastActivity = System.currentTimeMillis() + idleWorkers.enqueue(worker) + } + } else { + // Cleanup the worker socket. This will also cause the Python worker to exit. 
+ try { + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } + } } private object PythonWorkerFactory { val PROCESS_WAIT_TIMEOUT_MS = 10000 + val IDLE_WORKER_TIMEOUT_MS = 60000 // kill idle workers after 1 minute } diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index efc9009c088a8..6668797f5f8be 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -17,6 +17,8 @@ package org.apache.spark.api.python +import java.nio.ByteOrder + import scala.collection.JavaConversions._ import scala.util.Failure import scala.util.Try @@ -28,6 +30,55 @@ import org.apache.spark.rdd.RDD /** Utilities for serialization / deserialization between Python and Java, using Pickle. */ private[python] object SerDeUtil extends Logging { + // Unpickle array.array generated by Python 2.6 + class ArrayConstructor extends net.razorvine.pickle.objects.ArrayConstructor { + // /* Description of types */ + // static struct arraydescr descriptors[] = { + // {'c', sizeof(char), c_getitem, c_setitem}, + // {'b', sizeof(char), b_getitem, b_setitem}, + // {'B', sizeof(char), BB_getitem, BB_setitem}, + // #ifdef Py_USING_UNICODE + // {'u', sizeof(Py_UNICODE), u_getitem, u_setitem}, + // #endif + // {'h', sizeof(short), h_getitem, h_setitem}, + // {'H', sizeof(short), HH_getitem, HH_setitem}, + // {'i', sizeof(int), i_getitem, i_setitem}, + // {'I', sizeof(int), II_getitem, II_setitem}, + // {'l', sizeof(long), l_getitem, l_setitem}, + // {'L', sizeof(long), LL_getitem, LL_setitem}, + // {'f', sizeof(float), f_getitem, f_setitem}, + // {'d', sizeof(double), d_getitem, d_setitem}, + // {'\0', 0, 0, 0} /* Sentinel */ + // }; + // TODO: support Py_UNICODE with 2 bytes + // FIXME: unpickle array of float is wrong in Pyrolite, so we reverse the + // machine code for float/double here to workaround it. 
+ // we should fix this after Pyrolite fix them + val machineCodes: Map[Char, Int] = if (ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN)) { + Map('c' -> 1, 'B' -> 0, 'b' -> 1, 'H' -> 3, 'h' -> 5, 'I' -> 7, 'i' -> 9, + 'L' -> 11, 'l' -> 13, 'f' -> 14, 'd' -> 16, 'u' -> 21 + ) + } else { + Map('c' -> 1, 'B' -> 0, 'b' -> 1, 'H' -> 2, 'h' -> 4, 'I' -> 6, 'i' -> 8, + 'L' -> 10, 'l' -> 12, 'f' -> 15, 'd' -> 17, 'u' -> 20 + ) + } + override def construct(args: Array[Object]): Object = { + if (args.length == 1) { + construct(args ++ Array("")) + } else if (args.length == 2 && args(1).isInstanceOf[String]) { + val typecode = args(0).asInstanceOf[String].charAt(0) + val data: String = args(1).asInstanceOf[String] + construct(typecode, machineCodes(typecode), data.getBytes("ISO-8859-1")) + } else { + super.construct(args) + } + } + } + + def initialize() = { + Unpickler.registerConstructor("array", "array", new ArrayConstructor()) + } private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = { val pickle = new Pickler diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 0fdb5ae3c2e40..5ed3575816a38 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy import java.io.{File, PrintStream} -import java.lang.reflect.InvocationTargetException +import java.lang.reflect.{Modifier, InvocationTargetException} import java.net.URL import scala.collection.mutable.{ArrayBuffer, HashMap, Map} @@ -323,7 +323,9 @@ object SparkSubmit { } val mainMethod = mainClass.getMethod("main", new Array[String](0).getClass) - + if (!Modifier.isStatic(mainMethod.getModifiers)) { + throw new IllegalStateException("The main method in the given main class must be static") + } try { mainMethod.invoke(null, childArgs.toArray) } catch { diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index dd903dc65d204..acae448a9c66f 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -360,7 +360,16 @@ private[spark] class Executor( if (!taskRunner.attemptedTask.isEmpty) { Option(taskRunner.task).flatMap(_.metrics).foreach { metrics => metrics.updateShuffleReadMetrics - tasksMetrics += ((taskRunner.taskId, metrics)) + if (isLocal) { + // JobProgressListener will hold an reference of it during + // onExecutorMetricsUpdate(), then JobProgressListener can not see + // the changes of metrics any more, so make a deep copy of it + val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics)) + tasksMetrics += ((taskRunner.taskId, copiedMetrics)) + } else { + // It will be copied by serialization + tasksMetrics += ((taskRunner.taskId, metrics)) + } } } } diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index dcecb6beeea9b..e990c1da6730f 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -19,6 +19,7 @@ package org.apache.spark.network import java.io.{FileInputStream, RandomAccessFile, File, InputStream} import java.nio.ByteBuffer +import java.nio.channels.FileChannel import java.nio.channels.FileChannel.MapMode import 
com.google.common.io.ByteStreams @@ -66,8 +67,15 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt override def size: Long = length override def nioByteBuffer(): ByteBuffer = { - val channel = new RandomAccessFile(file, "r").getChannel - channel.map(MapMode.READ_ONLY, offset, length) + var channel: FileChannel = null + try { + channel = new RandomAccessFile(file, "r").getChannel + channel.map(MapMode.READ_ONLY, offset, length) + } finally { + if (channel != null) { + channel.close() + } + } } override def inputStream(): InputStream = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 6fcf9e31543ed..b2774dfc47553 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -507,11 +507,16 @@ class DAGScheduler( resultHandler: (Int, U) => Unit, properties: Properties = null) { + val start = System.nanoTime val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties) waiter.awaitResult() match { - case JobSucceeded => {} + case JobSucceeded => { + logInfo("Job %d finished: %s, took %f s".format + (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) + } case JobFailed(exception: Exception) => - logInfo("Failed to run " + callSite.shortForm) + logInfo("Job %d failed: %s, took %f s".format + (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9)) throw exception } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 5b5257269d92f..9a0cb1c6c6ccd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -292,7 +292,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A logInfo(s"Add WebUI Filter. 
$filterName, $filterParams, $proxyBase") conf.set("spark.ui.filters", filterName) conf.set(s"spark.$filterName.params", filterParams) - JettyUtils.addFilters(scheduler.sc.ui.getHandlers, conf) + scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) } } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index 513d74a08a47f..ee10aa061f4e9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -17,7 +17,6 @@ package org.apache.spark.scheduler.cluster -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, FileSystem} import org.apache.spark.{Logging, SparkContext, SparkEnv} @@ -47,16 +46,17 @@ private[spark] class SimrSchedulerBackend( val conf = SparkHadoopUtil.get.newConfiguration(sc.conf) val fs = FileSystem.get(conf) + val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("") logInfo("Writing to HDFS file: " + driverFilePath) logInfo("Writing Akka address: " + driverUrl) - logInfo("Writing Spark UI Address: " + sc.ui.appUIAddress) + logInfo("Writing Spark UI Address: " + appUIAddress) // Create temporary file to prevent race condition where executors get empty driverUrl file val temp = fs.create(tmpPath, true) temp.writeUTF(driverUrl) temp.writeInt(maxCores) - temp.writeUTF(sc.ui.appUIAddress) + temp.writeUTF(appUIAddress) temp.close() // "Atomic" rename diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 06872ace2ecf4..2f45d192e1d4d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -67,8 +67,10 @@ private[spark] class SparkDeploySchedulerBackend( val javaOpts = sparkJavaOpts ++ extraJavaOpts val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs, classPathEntries, libraryPathEntries, javaOpts) + val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("") + val eventLogDir = sc.eventLogger.map(_.logDir) val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, - sc.ui.appUIAddress, sc.eventLogger.map(_.logDir)) + appUIAddress, eventLogDir) client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf) client.start() diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index c8e708aa6b1bc..d868758a7f549 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashSet import scala.collection.mutable.Queue -import org.apache.spark.{TaskContext, Logging, SparkException} +import org.apache.spark.{TaskContext, Logging} import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener, BlockTransferService} import org.apache.spark.serializer.Serializer import org.apache.spark.util.Utils diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala 
b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index b0754e3ce10db..c4dddb2d1037e 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -205,7 +205,6 @@ private[spark] object JsonProtocol { } def taskInfoToJson(taskInfo: TaskInfo): JValue = { - val accumUpdateMap = taskInfo.accumulables ("Task ID" -> taskInfo.taskId) ~ ("Index" -> taskInfo.index) ~ ("Attempt" -> taskInfo.attempt) ~ diff --git a/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala b/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala new file mode 100644 index 0000000000000..f64e069cd1724 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * Exception thrown when there is an exception in + * executing the callback in TaskCompletionListener. + */ +private[spark] +class TaskCompletionListenerException(errorMessages: Seq[String]) extends Exception { + + override def getMessage: String = { + if (errorMessages.size == 1) { + errorMessages.head + } else { + errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n") + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 79943766d0f0f..c76b7af18481d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -530,7 +530,12 @@ private[spark] object Utils extends Logging { if (address.isLoopbackAddress) { // Address resolves to something like 127.0.1.1, which happens on Debian; try to find // a better address using the local network interfaces - for (ni <- NetworkInterface.getNetworkInterfaces) { + // getNetworkInterfaces returns ifs in reverse order compared to ifconfig output order + // on unix-like system. On windows, it returns in index order. + // It's more proper to pick ip address following system output order. + val activeNetworkIFs = NetworkInterface.getNetworkInterfaces.toList + val reOrderedNetworkIFs = if (isWindows) activeNetworkIFs else activeNetworkIFs.reverse + for (ni <- reOrderedNetworkIFs) { for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && !addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) { // We've found an address that looks reasonable! 
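The new `TaskCompletionListenerException` above collects failures from completion callbacks so that one misbehaving listener no longer hides the rest. A sketch of the behaviour, modeled on the test this patch adds to `TaskContextSuite`; the constructor arguments are arbitrary, and `markTaskCompleted` is `private[spark]`, so code like this only compiles inside Spark's own packages:

```scala
import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListenerException

val context = new TaskContext(stageId = 0, partitionId = 0, attemptId = 0)
context.addTaskCompletionListener(_ => throw new Exception("first listener failed"))
context.addTaskCompletionListener(_ => println("second listener still runs"))

try {
  // Both listeners are invoked; the error messages are aggregated.
  context.markTaskCompleted()
} catch {
  case e: TaskCompletionListenerException => println(e.getMessage)
}
```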
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 2744894277ae8..2e3fc5ef0e336 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -21,7 +21,6 @@ import java.lang.ref.WeakReference import scala.collection.mutable.{HashSet, SynchronizedSet} import scala.language.existentials -import scala.language.postfixOps import scala.util.Random import org.scalatest.{BeforeAndAfter, FunSuite} diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index 4b1d280624c57..5265ba904032f 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -26,8 +26,6 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.util.Utils -import scala.language.postfixOps - class DriverSuite extends FunSuite with Timeouts { test("driver should exit after finishing") { diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index 28197657e9bad..3b833f2e41867 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -22,7 +22,6 @@ import java.util.concurrent.Semaphore import scala.concurrent.{Await, TimeoutException} import scala.concurrent.duration.Duration import scala.concurrent.ExecutionContext.Implicits.global -import scala.language.postfixOps import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.scalatest.concurrent.Timeouts diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index db2ad829a48f9..faba5508c906c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -17,16 +17,20 @@ package org.apache.spark.scheduler +import org.mockito.Mockito._ +import org.mockito.Matchers.any + import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils +import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener} + class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { - test("Calls executeOnCompleteCallbacks after failure") { + test("calls TaskCompletionListener after failure") { TaskContextSuite.completed = false sc = new SparkContext("local", "test") val rdd = new RDD[String](sc, List()) { @@ -45,6 +49,20 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte } assert(TaskContextSuite.completed === true) } + + test("all TaskCompletionListeners should be called even if some fail") { + val context = new TaskContext(0, 0, 0) + val listener = mock(classOf[TaskCompletionListener]) + context.addTaskCompletionListener(_ => throw new Exception("blah")) + context.addTaskCompletionListener(listener) + context.addTaskCompletionListener(_ => throw new Exception("blah")) + + intercept[TaskCompletionListenerException] { + context.markTaskCompleted() + } + + verify(listener, times(1)).onTaskCompletion(any()) + } } private object TaskContextSuite { diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala 
b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 038746d2eda4b..48790b59e7fbd 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -21,7 +21,6 @@ import java.net.ServerSocket import javax.servlet.http.HttpServletRequest import scala.io.Source -import scala.language.postfixOps import scala.util.{Failure, Success, Try} import org.eclipse.jetty.server.Server @@ -36,11 +35,25 @@ import scala.xml.Node class UISuite extends FunSuite { + /** + * Create a test SparkContext with the SparkUI enabled. + * It is safe to `get` the SparkUI directly from the SparkContext returned here. + */ + private def newSparkContext(): SparkContext = { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set("spark.ui.enabled", "true") + val sc = new SparkContext(conf) + assert(sc.ui.isDefined) + sc + } + ignore("basic ui visibility") { - withSpark(new SparkContext("local", "test")) { sc => + withSpark(newSparkContext()) { sc => // test if the ui is visible, and all the expected tabs are visible eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL(sc.ui.appUIAddress).mkString + val html = Source.fromURL(sc.ui.get.appUIAddress).mkString assert(!html.contains("random data that should not be present")) assert(html.toLowerCase.contains("stages")) assert(html.toLowerCase.contains("storage")) @@ -51,7 +64,7 @@ class UISuite extends FunSuite { } ignore("visibility at localhost:4040") { - withSpark(new SparkContext("local", "test")) { sc => + withSpark(newSparkContext()) { sc => // test if visible from http://localhost:4040 eventually(timeout(10 seconds), interval(50 milliseconds)) { val html = Source.fromURL("http://localhost:4040").mkString @@ -61,8 +74,8 @@ class UISuite extends FunSuite { } ignore("attaching a new tab") { - withSpark(new SparkContext("local", "test")) { sc => - val sparkUI = sc.ui + withSpark(newSparkContext()) { sc => + val sparkUI = sc.ui.get val newTab = new WebUITab(sparkUI, "foo") { attachPage(new WebUIPage("") { @@ -73,7 +86,7 @@ class UISuite extends FunSuite { } sparkUI.attachTab(newTab) eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL(sc.ui.appUIAddress).mkString + val html = Source.fromURL(sparkUI.appUIAddress).mkString assert(!html.contains("random data that should not be present")) // check whether new page exists @@ -87,7 +100,7 @@ class UISuite extends FunSuite { } eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL(sc.ui.appUIAddress.stripSuffix("/") + "/foo").mkString + val html = Source.fromURL(sparkUI.appUIAddress.stripSuffix("/") + "/foo").mkString // check whether new page exists assert(html.contains("magic")) } @@ -129,16 +142,20 @@ class UISuite extends FunSuite { } test("verify appUIAddress contains the scheme") { - withSpark(new SparkContext("local", "test")) { sc => - val uiAddress = sc.ui.appUIAddress - assert(uiAddress.equals("http://" + sc.ui.appUIHostPort)) + withSpark(newSparkContext()) { sc => + val ui = sc.ui.get + val uiAddress = ui.appUIAddress + val uiHostPort = ui.appUIHostPort + assert(uiAddress.equals("http://" + uiHostPort)) } } test("verify appUIAddress contains the port") { - withSpark(new SparkContext("local", "test")) { sc => - val splitUIAddress = sc.ui.appUIAddress.split(':') - assert(splitUIAddress(2).toInt == sc.ui.boundPort) + withSpark(newSparkContext()) { sc => + val ui = sc.ui.get + val splitUIAddress = ui.appUIAddress.split(':') + 
val boundPort = ui.boundPort + assert(splitUIAddress(2).toInt == boundPort) } } } diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index d48c8bde12905..a8e92e36fe0d8 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -44,9 +44,9 @@ # Remote name which points to Apache git PUSH_REMOTE_NAME = os.environ.get("PUSH_REMOTE_NAME", "apache") # ASF JIRA username -JIRA_USERNAME = os.environ.get("JIRA_USERNAME", "") +JIRA_USERNAME = os.environ.get("JIRA_USERNAME", "pwendell") # ASF JIRA password -JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", "") +JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", "35500") GITHUB_BASE = "https://github.com/apache/spark/pull" GITHUB_API_BASE = "https://api.github.com/repos/apache/spark" diff --git a/dev/mima b/dev/mima index f9b9b03538f15..40603166c21ae 100755 --- a/dev/mima +++ b/dev/mima @@ -25,11 +25,19 @@ FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" echo -e "q\n" | sbt/sbt oldDeps/update +rm -f .generated-mima* + +# Generate Mima Ignore is called twice, first with latest built jars +# on the classpath and then again with previous version jars on the classpath. +# Because of a bug in GenerateMIMAIgnore that when old jars are ahead on classpath +# it did not process the new classes (which are in assembly jar). +./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore export SPARK_CLASSPATH="`find lib_managed \( -name '*spark*jar' -a -type f \) | tr "\\n" ":"`" echo "SPARK_CLASSPATH=$SPARK_CLASSPATH" ./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore + echo -e "q\n" | sbt/sbt mima-report-binary-issues | grep -v -e "info.*Resolving" ret_val=$? diff --git a/docs/README.md b/docs/README.md index 0a0126c5747d1..fdc89d2eb767a 100644 --- a/docs/README.md +++ b/docs/README.md @@ -23,8 +23,9 @@ The markdown code can be compiled to HTML using the [Jekyll tool](http://jekyllr To use the `jekyll` command, you will need to have Jekyll installed. The easiest way to do this is via a Ruby Gem, see the [jekyll installation instructions](http://jekyllrb.com/docs/installation). -If not already installed, you need to install `kramdown` with `sudo gem install kramdown`. -Execute `jekyll` from the `docs/` directory. Compiling the site with Jekyll will create a directory +If not already installed, you need to install `kramdown` and `jekyll-redirect-from` Gems +with `sudo gem install kramdown jekyll-redirect-from`. +Execute `jekyll build` from the `docs/` directory. Compiling the site with Jekyll will create a directory called `_site` containing index.html as well as the rest of the compiled files. You can modify the default Jekyll build as follows: diff --git a/docs/_config.yml b/docs/_config.yml index 45b78fe724a50..d3ea2625c7448 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -1,5 +1,7 @@ -pygments: true +highlighter: pygments markdown: kramdown +gems: + - jekyll-redirect-from # These allow the documentation to be updated with nerw releases # of Spark, Scala, and Mesos. diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index b30ab1e5218c0..a53e8a775b71f 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -109,7 +109,7 @@
  • Hardware Provisioning
  • 3rd-Party Hadoop Distros
- • Building Spark with Maven
+ • Building Spark
  • Contributing to Spark
  • diff --git a/docs/building-with-maven.md b/docs/building-spark.md similarity index 87% rename from docs/building-with-maven.md rename to docs/building-spark.md index bce7412c7d4c9..2378092d4a1a8 100644 --- a/docs/building-with-maven.md +++ b/docs/building-spark.md @@ -1,6 +1,7 @@ --- layout: global -title: Building Spark with Maven +title: Building Spark +redirect_from: "building-with-maven.html" --- * This will become a table of contents (this text will be scraped). @@ -159,4 +160,21 @@ then ship it over to the cluster. We are investigating the exact cause for this. The assembly jar produced by `mvn package` will, by default, include all of Spark's dependencies, including Hadoop and some of its ecosystem projects. On YARN deployments, this causes multiple versions of these to appear on executor classpaths: the version packaged in the Spark assembly and the version on each node, included with yarn.application.classpath. The `hadoop-provided` profile builds the assembly without including Hadoop-ecosystem projects, like ZooKeeper and Hadoop itself. +# Building with SBT +Maven is the official recommendation for packaging Spark, and is the "build of reference". +But SBT is supported for day-to-day development since it can provide much faster iterative +compilation. More advanced developers may wish to use SBT. + +The SBT build is derived from the Maven POM files, and so the same Maven profiles and variables +can be set to control the SBT build. For example: + + sbt/sbt -Pyarn -Phadoop-2.3 compile + +# Speeding up Compilation with Zinc + +[Zinc](https://github.com/typesafehub/zinc) is a long-running server version of SBT's incremental +compiler. When run locally as a background process, it speeds up builds of Scala-based projects +like Spark. Developers who regularly recompile Spark with Maven will be the most interested in +Zinc. The project site gives instructions for building and running `zinc`; OS X users can +install it using `brew install zinc`. \ No newline at end of file diff --git a/docs/configuration.md b/docs/configuration.md index 36178efb97103..af16489a44281 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -206,6 +206,16 @@ Apart from these, the following properties are also available, and may be useful used during aggregation goes above this amount, it will spill the data into disks. + + spark.python.worker.reuse + true + + Reuse Python worker or not. If yes, it will use a fixed number of Python workers, + does not need to fork() a Python process for every tasks. It will be very useful + if there is large broadcast, then the broadcast will not be needed to transfered + from JVM to Python worker for every task. + + spark.executorEnv.[EnvironmentVariableName] (none) diff --git a/docs/hadoop-third-party-distributions.md b/docs/hadoop-third-party-distributions.md index ab1023b8f1842..dd73e9dc54440 100644 --- a/docs/hadoop-third-party-distributions.md +++ b/docs/hadoop-third-party-distributions.md @@ -11,7 +11,7 @@ with these distributions: When compiling Spark, you'll need to specify the Hadoop version by defining the `hadoop.version` property. For certain versions, you will need to specify additional profiles. 
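When building an application against one of these distributions, the matching client library is also declared in the application's own build, as the README text removed earlier in this patch explained. A sketch in sbt's Scala DSL; the version string is only an example and must match the cluster:

```scala
// build.sbt -- hypothetical version; use the hadoop-client release that
// matches the Hadoop distribution the cluster runs.
libraryDependencies += "org.apache.hadoop" % "hadoop-client" % "2.2.0"
```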
For more detail, -see the guide on [building with maven](building-with-maven.html#specifying-the-hadoop-version): +see the guide on [building with maven](building-spark.html#specifying-the-hadoop-version): mvn -Dhadoop.version=1.0.4 -DskipTests clean package mvn -Phadoop-2.2 -Dhadoop.version=2.2.0 -DskipTests clean package diff --git a/docs/index.md b/docs/index.md index 7fe6b43d32af7..e8ebadbd4e427 100644 --- a/docs/index.md +++ b/docs/index.md @@ -12,7 +12,7 @@ It also supports a rich set of higher-level tools including [Spark SQL](sql-prog Get Spark from the [downloads page](http://spark.apache.org/downloads.html) of the project website. This documentation is for Spark version {{site.SPARK_VERSION}}. The downloads page contains Spark packages for many popular HDFS versions. If you'd like to build Spark from -scratch, visit [building Spark with Maven](building-with-maven.html). +scratch, visit [Building Spark](building-spark.html). Spark runs on both Windows and UNIX-like systems (e.g. Linux, Mac OS). It's easy to run locally on one machine --- all you need is to have `java` installed on your system `PATH`, @@ -105,7 +105,7 @@ options for deployment: * [3rd Party Hadoop Distributions](hadoop-third-party-distributions.html): using common Hadoop distributions * Integration with other storage systems: * [OpenStack Swift](storage-openstack-swift.html) -* [Building Spark with Maven](building-with-maven.html): build Spark using the Maven system +* [Building Spark](building-spark.html): build Spark using the Maven system * [Contributing to Spark](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) **External Resources:** diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index d8b22f3663d08..74bcc2eeb65f6 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -11,7 +11,7 @@ was added to Spark in version 0.6.0, and improved in subsequent releases. Running Spark-on-YARN requires a binary distribution of Spark which is built with YARN support. Binary distributions can be downloaded from the Spark project website. -To build Spark yourself, refer to the [building with Maven guide](building-with-maven.html). +To build Spark yourself, refer to [Building Spark](building-spark.html). # Configuration @@ -155,6 +155,7 @@ For example: --driver-memory 4g \ --executor-memory 2g \ --executor-cores 1 \ + --queue thequeue \ lib/spark-examples*.jar \ 10 diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d83efa4bab324..c498b41c43380 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -128,7 +128,7 @@ feature parity with a HiveContext. -The specific variant of SQL that is used to parse queries can also be selected using the +The specific variant of SQL that is used to parse queries can also be selected using the `spark.sql.dialect` option. This parameter can be changed using either the `setConf` method on a SQLContext or by using a `SET key=value` command in SQL. For a SQLContext, the only dialect available is "sql" which uses a simple SQL parser provided by Spark SQL. In a HiveContext, the @@ -139,7 +139,7 @@ default is "hiveql", though "sql" is also available. Since the HiveQL parser is Spark SQL supports operating on a variety of data sources through the `SchemaRDD` interface. A SchemaRDD can be operated on as normal RDDs and can also be registered as a temporary table. -Registering a SchemaRDD as a table allows you to run SQL queries over its data. 
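As a concrete illustration of the sentence just above (registering a SchemaRDD as a table and running SQL over it), a sketch against the Scala API of this era; the case class and sample data are made up:

```scala
import org.apache.spark.sql.SQLContext

case class Person(name: String, age: Int)

val sqlContext = new SQLContext(sc)   // sc is an existing SparkContext
import sqlContext.createSchemaRDD     // implicitly converts RDDs of case classes

val people = sc.parallelize(Seq(Person("Alice", 19), Person("Bob", 35)))
people.registerTempTable("people")

val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
```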
This section +Registering a SchemaRDD as a table allows you to run SQL queries over its data. This section describes the various methods for loading data into a SchemaRDD. ## RDDs @@ -152,7 +152,7 @@ while writing your Spark application. The second method for creating SchemaRDDs is through a programmatic interface that allows you to construct a schema and then apply it to an existing RDD. While this method is more verbose, it allows you to construct SchemaRDDs when the columns and their types are not known until runtime. - + ### Inferring the Schema Using Reflection
    @@ -193,7 +193,7 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
    Spark SQL supports automatically converting an RDD of [JavaBeans](http://stackoverflow.com/questions/3295496/what-is-a-javabean-exactly) -into a Schema RDD. The BeanInfo, obtained using reflection, defines the schema of the table. +into a Schema RDD. The BeanInfo, obtained using reflection, defines the schema of the table. Currently, Spark SQL does not support JavaBeans that contain nested or contain complex types such as Lists or Arrays. You can create a JavaBean by creating a class that implements Serializable and has getters and setters for all of its fields. @@ -480,7 +480,7 @@ for name in names.collect(): [Parquet](http://parquet.io) is a columnar format that is supported by many other data processing systems. Spark SQL provides support for both reading and writing Parquet files that automatically preserves the schema -of the original data. +of the original data. ### Loading Data Programmatically @@ -562,7 +562,7 @@ for teenName in teenNames.collect():
    -
    + ### Configuration @@ -808,7 +808,7 @@ memory usage and GC pressure. You can call `uncacheTable("tableName")` to remove Note that if you call `cache` rather than `cacheTable`, tables will _not_ be cached using the in-memory columnar format, and therefore `cacheTable` is strongly recommended for this use case. -Configuration of in-memory caching can be done using the `setConf` method on SQLContext or by running +Configuration of in-memory caching can be done using the `setConf` method on SQLContext or by running `SET key=value` commands using SQL. @@ -881,10 +881,32 @@ To start the JDBC server, run the following in the Spark directory: ./sbin/start-thriftserver.sh -The default port the server listens on is 10000. To listen on customized host and port, please set -the `HIVE_SERVER2_THRIFT_PORT` and `HIVE_SERVER2_THRIFT_BIND_HOST` environment variables. You may -run `./sbin/start-thriftserver.sh --help` for a complete list of all available options. Now you can -use beeline to test the Thrift JDBC server: +This script accepts all `bin/spark-submit` command line options, plus a `--hiveconf` option to +specify Hive properties. You may run `./sbin/start-thriftserver.sh --help` for a complete list of +all available options. By default, the server listens on localhost:10000. You may override this +bahaviour via either environment variables, i.e.: + +{% highlight bash %} +export HIVE_SERVER2_THRIFT_PORT= +export HIVE_SERVER2_THRIFT_BIND_HOST= +./sbin/start-thriftserver.sh \ + --master \ + ... +``` +{% endhighlight %} + +or system properties: + +{% highlight bash %} +./sbin/start-thriftserver.sh \ + --hiveconf hive.server2.thrift.port= \ + --hiveconf hive.server2.thrift.bind.host= \ + --master + ... +``` +{% endhighlight %} + +Now you can use beeline to test the Thrift JDBC server: ./bin/beeline @@ -918,7 +940,6 @@ options. ## Migration Guide for Shark User ### Scheduling -s To set a [Fair Scheduler](job-scheduling.html#fair-scheduler-pools) pool for a JDBC client session, users can set the `spark.sql.thriftserver.scheduler.pool` variable: @@ -931,7 +952,7 @@ SQL deprecates this property in favor of `spark.sql.shuffle.partitions`, whose d is 200. Users may customize this property via `SET`: SET spark.sql.shuffle.partitions=10; - SELECT page, count(*) c + SELECT page, count(*) c FROM logs_last_month_cached GROUP BY page ORDER BY c DESC LIMIT 10; @@ -1110,7 +1131,7 @@ evaluated by the SQL execution engine. A full list of the functions supported c The range of numbers is from `-9223372036854775808` to `9223372036854775807`. - `FloatType`: Represents 4-byte single-precision floating point numbers. - `DoubleType`: Represents 8-byte double-precision floating point numbers. - - `DecimalType`: + - `DecimalType`: Represents arbitrary-precision signed decimal numbers. Backed internally by `java.math.BigDecimal`. A `BigDecimal` consists of an arbitrary precision integer unscaled value and a 32-bit integer scale. * String type - `StringType`: Represents character string values. * Binary type @@ -1140,7 +1161,7 @@ evaluated by the SQL execution engine. A full list of the functions supported c
    All data types of Spark SQL are located in the package `org.apache.spark.sql`. -You can access them by doing +You can access them by doing {% highlight scala %} import org.apache.spark.sql._ {% endhighlight %} @@ -1232,7 +1253,7 @@ import org.apache.spark.sql._
    @@ -1246,7 +1267,7 @@ import org.apache.spark.sql._ -
    scala.collection.Seq ArrayType(elementType, [containsNull])
    - Note: The default value of containsNull is false. + Note: The default value of containsNull is true.
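To make the new default concrete, here is a minimal sketch (assuming the single-argument `ArrayType(elementType)` factory in `org.apache.spark.sql` picks up the new `containsNull = true` default described above; the value names are illustrative only):

{% highlight scala %}
import org.apache.spark.sql._

// With the new default, an array type built without an explicit flag admits null elements.
val tags = ArrayType(StringType)                          // containsNull defaults to true
// Pass the flag explicitly when the array is known to contain no null elements.
val scores = ArrayType(DoubleType, containsNull = false)
{% endhighlight %}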
    StructType org.apache.spark.sql.Row + StructType(fields)
    Note: fields is a Seq of StructFields. Also, two fields with the same name are not allowed. @@ -1268,7 +1289,7 @@ import org.apache.spark.sql._ All data types of Spark SQL are located in the package of `org.apache.spark.sql.api.java`. To access or create a data type, -please use factory methods provided in +please use factory methods provided in `org.apache.spark.sql.api.java.DataType`. @@ -1358,7 +1379,7 @@ please use factory methods provided in @@ -1374,7 +1395,7 @@ please use factory methods provided in - @@ -1519,7 +1540,7 @@ from pyspark.sql import * -
    java.util.List DataType.createArrayType(elementType)
    - Note: The value of containsNull will be false
    + Note: The value of containsNull will be true
    DataType.createArrayType(elementType, containsNull).
    StructType org.apache.spark.sql.api.java + DataType.createStructType(fields)
    Note: fields is a List or an array of StructFields. Also, two fields with the same name are not allowed. @@ -1395,7 +1416,7 @@ please use factory methods provided in
    All data types of Spark SQL are located in the package of `pyspark.sql`. -You can access them by doing +You can access them by doing {% highlight python %} from pyspark.sql import * {% endhighlight %} @@ -1505,7 +1526,7 @@ from pyspark.sql import *
    list, tuple, or array ArrayType(elementType, [containsNull])
    - Note: The default value of containsNull is False. + Note: The default value of containsNull is True.
    StructType list or tuple + StructType(fields)
    Note: fields is a Seq of StructFields. Also, two fields with the same name are not allowed. diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index c6090d9ec30c7..379eb513d521e 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -108,7 +108,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m #### Running the Example To run the example, -- Download Spark source and follow the [instructions](building-with-maven.html) to build Spark with profile *-Pkinesis-asl*. +- Download Spark source and follow the [instructions](building-spark.html) to build Spark with profile *-Pkinesis-asl*. mvn -Pkinesis-asl -DskipTests clean package diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index bfd07593b92ed..5682e96aa8770 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -52,7 +52,7 @@ class UsageError(Exception): def parse_args(): parser = OptionParser( usage="spark-ec2 [options] " - + "\n\n can be: launch, destroy, login, stop, start, get-master", + + "\n\n can be: launch, destroy, login, stop, start, get-master, reboot-slaves", add_help_option=False) parser.add_option( "-h", "--help", action="help", @@ -950,6 +950,20 @@ def real_main(): subprocess.check_call( ssh_command(opts) + proxy_opt + ['-t', '-t', "%s@%s" % (opts.user, master)]) + elif action == "reboot-slaves": + response = raw_input( + "Are you sure you want to reboot the cluster " + + cluster_name + " slaves?\n" + + "Reboot cluster slaves " + cluster_name + " (y/N): ") + if response == "y": + (master_nodes, slave_nodes) = get_existing_cluster( + conn, opts, cluster_name, die_on_error=False) + print "Rebooting slaves..." + for inst in slave_nodes: + if inst.state not in ["shutting-down", "terminated"]: + print "Rebooting " + inst.id + inst.reboot() + elif action == "get-master": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) print master_nodes[0].public_dns_name diff --git a/examples/pom.xml b/examples/pom.xml index 3f46c40464d3b..2b561857f9f33 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -203,6 +203,20 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + org.apache.maven.plugins + maven-install-plugin + + true + + org.apache.maven.plugins maven-shade-plugin diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 72c3ab475b61f..4683e6eb966be 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -55,6 +55,8 @@ object DecisionTreeRunner { maxDepth: Int = 5, impurity: ImpurityType = Gini, maxBins: Int = 32, + minInstancesPerNode: Int = 1, + minInfoGain: Double = 0.0, fracTest: Double = 0.2) def main(args: Array[String]) { @@ -75,6 +77,13 @@ object DecisionTreeRunner { opt[Int]("maxBins") .text(s"max number of bins, default: ${defaultParams.maxBins}") .action((x, c) => c.copy(maxBins = x)) + opt[Int]("minInstancesPerNode") + .text(s"min number of instances required at child nodes to create the parent split," + + s" default: ${defaultParams.minInstancesPerNode}") + .action((x, c) => c.copy(minInstancesPerNode = x)) + opt[Double]("minInfoGain") + .text(s"min info gain required to create a split, default: 
${defaultParams.minInfoGain}") + .action((x, c) => c.copy(minInfoGain = x)) opt[Double]("fracTest") .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}") .action((x, c) => c.copy(fracTest = x)) @@ -179,7 +188,9 @@ object DecisionTreeRunner { impurity = impurityCalculator, maxDepth = params.maxDepth, maxBins = params.maxBins, - numClassesForClassification = numClasses) + numClassesForClassification = numClasses, + minInstancesPerNode = params.minInstancesPerNode, + minInfoGain = params.minInfoGain) val model = DecisionTree.train(training, strategy) println(model) diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 8658ecf5abfab..7e478bed62da7 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -74,6 +74,20 @@ + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + org.apache.maven.plugins + maven-install-plugin + + true + + org.apache.maven.plugins maven-surefire-plugin diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 614555a054dfb..257e2f3a36115 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -79,30 +79,43 @@ object PageRank extends Logging { def run[VD: ClassTag, ED: ClassTag]( graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15): Graph[Double, Double] = { - // Initialize the pagerankGraph with each edge attribute having + // Initialize the PageRank graph with each edge attribute having // weight 1/outDegree and each vertex with attribute 1.0. - val pagerankGraph: Graph[Double, Double] = graph + var rankGraph: Graph[Double, Double] = graph // Associate the degree with each vertex .outerJoinVertices(graph.outDegrees) { (vid, vdata, deg) => deg.getOrElse(0) } // Set the weight on the edges based on the degree .mapTriplets( e => 1.0 / e.srcAttr ) // Set the vertex attributes to the initial pagerank values - .mapVertices( (id, attr) => 1.0 ) - .cache() + .mapVertices( (id, attr) => resetProb ) - // Define the three functions needed to implement PageRank in the GraphX - // version of Pregel - def vertexProgram(id: VertexId, attr: Double, msgSum: Double): Double = - resetProb + (1.0 - resetProb) * msgSum - def sendMessage(edge: EdgeTriplet[Double, Double]) = - Iterator((edge.dstId, edge.srcAttr * edge.attr)) - def messageCombiner(a: Double, b: Double): Double = a + b - // The initial message received by all vertices in PageRank - val initialMessage = 0.0 + var iteration = 0 + var prevRankGraph: Graph[Double, Double] = null + while (iteration < numIter) { + rankGraph.cache() - // Execute pregel for a fixed number of iterations. - Pregel(pagerankGraph, initialMessage, numIter, activeDirection = EdgeDirection.Out)( - vertexProgram, sendMessage, messageCombiner) + // Compute the outgoing rank contributions of each vertex, perform local preaggregation, and + // do the final aggregation at the receiving vertices. Requires a shuffle for aggregation. + val rankUpdates = rankGraph.mapReduceTriplets[Double]( + e => Iterator((e.dstId, e.srcAttr * e.attr)), _ + _) + + // Apply the final rank updates to get the new ranks, using join to preserve ranks of vertices + // that didn't receive a message. Requires a shuffle for broadcasting updated ranks to the + // edge partitions. 
+ prevRankGraph = rankGraph + rankGraph = rankGraph.joinVertices(rankUpdates) { + (id, oldRank, msgSum) => resetProb + (1.0 - resetProb) * msgSum + }.cache() + + rankGraph.edges.foreachPartition(x => {}) // also materializes rankGraph.vertices + logInfo(s"PageRank finished iteration $iteration.") + prevRankGraph.vertices.unpersist(false) + prevRankGraph.edges.unpersist(false) + + iteration += 1 + } + + rankGraph } /** diff --git a/make-distribution.sh b/make-distribution.sh index 9b012b9222db4..884659954a491 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -40,7 +40,7 @@ function exit_with_usage { echo "" echo "usage:" echo "./make-distribution.sh [--name] [--tgz] [--with-tachyon] " - echo "See Spark's \"Building with Maven\" doc for correct Maven options." + echo "See Spark's \"Building Spark\" doc for correct Maven options." echo "" exit 1 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 4343124f102a0..fa0fa69f38634 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -303,7 +303,9 @@ class PythonMLLibAPI extends Serializable { categoricalFeaturesInfoJMap: java.util.Map[Int, Int], impurityStr: String, maxDepth: Int, - maxBins: Int): DecisionTreeModel = { + maxBins: Int, + minInstancesPerNode: Int, + minInfoGain: Double): DecisionTreeModel = { val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint) @@ -316,7 +318,9 @@ class PythonMLLibAPI extends Serializable { maxDepth = maxDepth, numClassesForClassification = numClasses, maxBins = maxBins, - categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap) + categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap, + minInstancesPerNode = minInstancesPerNode, + minInfoGain = minInfoGain) DecisionTree.train(data, strategy) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 486bdbfa9cb47..84d3c7cebd7c8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -84,7 +84,7 @@ class LogisticRegressionWithSGD private ( extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable { private val gradient = new LogisticGradient() - private val updater = new SimpleUpdater() + private val updater = new SquaredL2Updater() override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index ac6eaea3f43ad..5c1acca0ec532 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -76,16 +76,12 @@ class IndexedRowMatrix( } /** - * Computes the singular value decomposition of this matrix. + * Computes the singular value decomposition of this IndexedRowMatrix. * Denote this matrix by A (m x n), this will compute matrices U, S, V such that A = U * S * V'. 
* - * There is no restriction on m, but we require `n^2` doubles to fit in memory. - * Further, n should be less than m. - - * The decomposition is computed by first computing A'A = V S^2 V', - * computing svd locally on that (since n x n is small), from which we recover S and V. - * Then we compute U via easy matrix multiplication as U = A * (V * S^-1). - * Note that this approach requires `O(n^3)` time on the master node. + * The cost and implementation of this method is identical to that in + * [[org.apache.spark.mllib.linalg.distributed.RowMatrix]] + * With the addition of indices. * * At most k largest non-zero singular values and associated vectors are returned. * If there are k such values, then the dimensions of the return will be: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index d1309b2b20f54..c7f2576c822b1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -87,17 +87,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val maxDepth = strategy.maxDepth require(maxDepth <= 30, s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.") - // Number of nodes to allocate: max number of nodes possible given the depth of the tree, plus 1 - val maxNumNodesPlus1 = Node.startIndexInLevel(maxDepth + 1) - // Initialize an array to hold parent impurity calculations for each node. - val parentImpurities = new Array[Double](maxNumNodesPlus1) - // dummy value for top node (updated during first split calculation) - val nodes = new Array[Node](maxNumNodesPlus1) // Calculate level for single group construction // Max memory usage for aggregates - val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 + val maxMemoryUsage = strategy.maxMemoryInMB * 1024L * 1024L logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") // TODO: Calculate memory usage more precisely. val numElementsPerNode = DecisionTree.getElementsPerNode(metadata) @@ -120,80 +114,35 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo * beforehand and is not used in later levels. */ + var topNode: Node = null // set on first iteration var level = 0 var break = false while (level <= maxDepth && !break) { - logDebug("#####################################") logDebug("level = " + level) logDebug("#####################################") // Find best split for all nodes at a level. timer.start("findBestSplits") - val splitsStatsForLevel: Array[(Split, InformationGainStats)] = - DecisionTree.findBestSplits(treeInput, parentImpurities, - metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer) + val (tmpTopNode: Node, doneTraining: Boolean) = DecisionTree.findBestSplits(treeInput, + metadata, level, topNode, splits, bins, maxLevelForSingleGroup, timer) timer.stop("findBestSplits") - val levelNodeIndexOffset = Node.startIndexInLevel(level) - for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { - val nodeIndex = levelNodeIndexOffset + index - - // Extract info for this node (index) at the current level. 
- timer.start("extractNodeInfo") - val split = nodeSplitStats._1 - val stats = nodeSplitStats._2 - val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth) - val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) - logDebug("Node = " + node) - nodes(nodeIndex) = node - timer.stop("extractNodeInfo") - - if (level != 0) { - // Set parent. - val parentNodeIndex = Node.parentIndex(nodeIndex) - if (Node.isLeftChild(nodeIndex)) { - nodes(parentNodeIndex).leftNode = Some(nodes(nodeIndex)) - } else { - nodes(parentNodeIndex).rightNode = Some(nodes(nodeIndex)) - } - } - // Extract info for nodes at the next lower level. - timer.start("extractInfoForLowerLevels") - if (level < maxDepth) { - val leftChildIndex = Node.leftChildIndex(nodeIndex) - val leftImpurity = stats.leftImpurity - logDebug("leftChildIndex = " + leftChildIndex + ", impurity = " + leftImpurity) - parentImpurities(leftChildIndex) = leftImpurity - - val rightChildIndex = Node.rightChildIndex(nodeIndex) - val rightImpurity = stats.rightImpurity - logDebug("rightChildIndex = " + rightChildIndex + ", impurity = " + rightImpurity) - parentImpurities(rightChildIndex) = rightImpurity - } - timer.stop("extractInfoForLowerLevels") - logDebug("final best split = " + split) + if (level == 0) { + topNode = tmpTopNode } - require(Node.maxNodesInLevel(level) == splitsStatsForLevel.length) - // Check whether all the nodes at the current level at leaves. - val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) - logDebug("all leaf = " + allLeaf) - if (allLeaf) { - break = true // no more tree construction - } else { - level += 1 + if (doneTraining) { + break = true + logDebug("done training") } + + level += 1 } logDebug("#####################################") logDebug("Extracting tree model") logDebug("#####################################") - // Initialize the top or root node of the tree. - val topNode = nodes(1) - // Build the full tree using the node info calculated in the level-wise best split calculations. - topNode.build(nodes) - timer.stop("total") logInfo("Internal timing for DecisionTree:") @@ -408,24 +357,26 @@ object DecisionTree extends Serializable with Logging { * multiple groups if the level-wise training task could lead to memory overflow. * * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] - * @param parentImpurities Impurities for all parent nodes for the current level * @param metadata Learning and dataset metadata * @param level Level of the tree + * @param topNode Root node of the tree (or invalid node when training first level). * @param splits possible splits for all features, indexed (numFeatures)(numSplits) * @param bins possible bins for all features, indexed (numFeatures)(numBins) * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. - * @return array (over nodes) of splits with best split for each node at a given level. + * @return (root, doneTraining) where: + * root = Root node (which is newly created on the first iteration), + * doneTraining = true if no more internal nodes were created. 
*/ private[tree] def findBestSplits( input: RDD[TreePoint], - parentImpurities: Array[Double], metadata: DecisionTreeMetadata, level: Int, - nodes: Array[Node], + topNode: Node, splits: Array[Array[Split]], bins: Array[Array[Bin]], maxLevelForSingleGroup: Int, - timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = { + timer: TimeTracker = new TimeTracker): (Node, Boolean) = { + // split into groups to avoid memory overflow during aggregation if (level > maxLevelForSingleGroup) { // When information for all nodes at a given level cannot be stored in memory, @@ -434,18 +385,18 @@ object DecisionTree extends Serializable with Logging { // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. val numGroups = 1 << level - maxLevelForSingleGroup logDebug("numGroups = " + numGroups) - var bestSplits = new Array[(Split, InformationGainStats)](0) // Iterate over each group of nodes at a level. var groupIndex = 0 + var doneTraining = true while (groupIndex < numGroups) { - val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, metadata, level, - nodes, splits, bins, timer, numGroups, groupIndex) - bestSplits = Array.concat(bestSplits, bestSplitsForGroup) + val (_, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level, + topNode, splits, bins, timer, numGroups, groupIndex) + doneTraining = doneTraining && doneTrainingGroup groupIndex += 1 } - bestSplits + (topNode, doneTraining) // Not first iteration, so topNode was already set. } else { - findBestSplitsPerGroup(input, parentImpurities, metadata, level, nodes, splits, bins, timer) + findBestSplitsPerGroup(input, metadata, level, topNode, splits, bins, timer) } } @@ -585,27 +536,27 @@ object DecisionTree extends Serializable with Logging { * Returns an array of optimal splits for a group of nodes at a given level * * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] - * @param parentImpurities Impurities for all parent nodes for the current level * @param metadata Learning and dataset metadata * @param level Level of the tree - * @param nodes Array of all nodes in the tree. Used for matching data points to nodes. + * @param topNode Root node of the tree (or invalid node when training first level). * @param splits possible splits for all features, indexed (numFeatures)(numSplits) * @param bins possible bins for all features, indexed (numFeatures)(numBins) * @param numGroups total number of node groups at the current level. Default value is set to 1. * @param groupIndex index of the node group being processed. Default value is set to 0. - * @return array of splits with best splits for all nodes at a given level. + * @return (root, doneTraining) where: + * root = Root node (which is newly created on the first iteration), + * doneTraining = true if no more internal nodes were created. */ private def findBestSplitsPerGroup( input: RDD[TreePoint], - parentImpurities: Array[Double], metadata: DecisionTreeMetadata, level: Int, - nodes: Array[Node], + topNode: Node, splits: Array[Array[Split]], bins: Array[Array[Bin]], timer: TimeTracker, numGroups: Int = 1, - groupIndex: Int = 0): Array[(Split, InformationGainStats)] = { + groupIndex: Int = 0): (Node, Boolean) = { /* * The high-level descriptions of the best split optimizations are noted here. 
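For readability, the level-wise control flow that consumes this `(root, doneTraining)` return value, as rewritten in `DecisionTree.run` above, condenses to roughly the following sketch. The wrapper object name is hypothetical, the inputs are assumed to be prepared exactly as in the surrounding code, and it must live in the `tree` package because `findBestSplits` is `private[tree]`:

{% highlight scala %}
package org.apache.spark.mllib.tree

import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
import org.apache.spark.mllib.tree.model.{Bin, Node, Split}
import org.apache.spark.rdd.RDD

private[tree] object LevelWiseTrainingSketch {
  // Captures the root on the first call (level 0) and stops as soon as a level
  // produces no new internal nodes (doneTraining == true).
  def train(
      treeInput: RDD[TreePoint],
      metadata: DecisionTreeMetadata,
      maxDepth: Int,
      splits: Array[Array[Split]],
      bins: Array[Array[Bin]],
      maxLevelForSingleGroup: Int): Node = {
    var topNode: Node = null
    var level = 0
    var break = false
    while (level <= maxDepth && !break) {
      val (tmpTopNode, doneTraining) = DecisionTree.findBestSplits(
        treeInput, metadata, level, topNode, splits, bins, maxLevelForSingleGroup)
      if (level == 0) {
        topNode = tmpTopNode  // the root node is created while training level 0
      }
      if (doneTraining) {
        break = true          // no new internal nodes were created
      }
      level += 1
    }
    topNode
  }
}
{% endhighlight %}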
@@ -662,7 +613,7 @@ object DecisionTree extends Serializable with Logging { 0 } else { val globalNodeIndex = - predictNodeIndex(nodes(1), treePoint.binnedFeatures, bins, metadata.unorderedFeatures) + predictNodeIndex(topNode, treePoint.binnedFeatures, bins, metadata.unorderedFeatures) globalNodeIndex - globalNodeIndexOffset } } @@ -705,56 +656,81 @@ object DecisionTree extends Serializable with Logging { // Calculate best splits for all nodes at a given level timer.start("chooseSplits") - val bestSplits = new Array[(Split, InformationGainStats)](numNodes) - // Iterating over all nodes at this level + // On the first iteration, we need to get and return the newly created root node. + var newTopNode: Node = topNode + + // Iterate over all nodes at this level var nodeIndex = 0 + var internalNodeCount = 0 while (nodeIndex < numNodes) { - val nodeImpurity = parentImpurities(globalNodeIndexOffset + nodeIndex) - logDebug("node impurity = " + nodeImpurity) - bestSplits(nodeIndex) = - binsToBestSplit(binAggregates, nodeIndex, nodeImpurity, level, metadata, splits) - logDebug("best split = " + bestSplits(nodeIndex)._1) + val (split: Split, stats: InformationGainStats, predict: Predict) = + binsToBestSplit(binAggregates, nodeIndex, level, metadata, splits) + logDebug("best split = " + split) + + val globalNodeIndex = globalNodeIndexOffset + nodeIndex + + // Extract info for this node at the current level. + val isLeaf = (stats.gain <= 0) || (level == metadata.maxDepth) + val node = + new Node(globalNodeIndex, predict.predict, isLeaf, Some(split), None, None, Some(stats)) + logDebug("Node = " + node) + + if (!isLeaf) { + internalNodeCount += 1 + } + if (level == 0) { + newTopNode = node + } else { + // Set parent. + val parentNode = Node.getNode(Node.parentIndex(globalNodeIndex), topNode) + if (Node.isLeftChild(globalNodeIndex)) { + parentNode.leftNode = Some(node) + } else { + parentNode.rightNode = Some(node) + } + } + if (level < metadata.maxDepth) { + logDebug("leftChildIndex = " + Node.leftChildIndex(globalNodeIndex) + + ", impurity = " + stats.leftImpurity) + logDebug("rightChildIndex = " + Node.rightChildIndex(globalNodeIndex) + + ", impurity = " + stats.rightImpurity) + } + nodeIndex += 1 } timer.stop("chooseSplits") - bestSplits + val doneTraining = internalNodeCount == 0 + (newTopNode, doneTraining) } /** * Calculate the information gain for a given (feature, split) based upon left/right aggregates. * @param leftImpurityCalculator left node aggregates for this (feature, split) * @param rightImpurityCalculator right node aggregate for this (feature, split) - * @param topImpurity impurity of the parent node * @return information gain and statistics for all splits */ private def calculateGainForSplit( leftImpurityCalculator: ImpurityCalculator, rightImpurityCalculator: ImpurityCalculator, - topImpurity: Double, level: Int, metadata: DecisionTreeMetadata): InformationGainStats = { - val leftCount = leftImpurityCalculator.count val rightCount = rightImpurityCalculator.count - val totalCount = leftCount + rightCount - if (totalCount == 0) { - // Return arbitrary prediction. - return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) + // If left child or right child doesn't satisfy minimum instances per node, + // then this split is invalid, return invalid information gain stats. 
+ if ((leftCount < metadata.minInstancesPerNode) || + (rightCount < metadata.minInstancesPerNode)) { + return InformationGainStats.invalidInformationGainStats } + val totalCount = leftCount + rightCount + val parentNodeAgg = leftImpurityCalculator.copy parentNodeAgg.add(rightImpurityCalculator) - // impurity of parent node - val impurity = if (level > 0) { - topImpurity - } else { - parentNodeAgg.calculate() - } - val predict = parentNodeAgg.predict - val prob = parentNodeAgg.prob(predict) + val impurity = parentNodeAgg.calculate() val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 val rightImpurity = rightImpurityCalculator.calculate() @@ -764,28 +740,51 @@ object DecisionTree extends Serializable with Logging { val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) + // if information gain doesn't satisfy minimum information gain, + // then this split is invalid, return invalid information gain stats. + if (gain < metadata.minInfoGain) { + return InformationGainStats.invalidInformationGainStats + } + + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity) + } + + /** + * Calculate predict value for current node, given stats of any split. + * Note that this function is called only once for each node. + * @param leftImpurityCalculator left node aggregates for a split + * @param rightImpurityCalculator right node aggregates for a node + * @return predict value for current node + */ + private def calculatePredict( + leftImpurityCalculator: ImpurityCalculator, + rightImpurityCalculator: ImpurityCalculator): Predict = { + val parentNodeAgg = leftImpurityCalculator.copy + parentNodeAgg.add(rightImpurityCalculator) + val predict = parentNodeAgg.predict + val prob = parentNodeAgg.prob(predict) + + new Predict(predict, prob) } /** * Find the best split for a node. * @param binAggregates Bin statistics. * @param nodeIndex Index for node to split in this (level, group). - * @param nodeImpurity Impurity of the node (nodeIndex). * @return tuple for best split: (Split, information gain) */ private def binsToBestSplit( binAggregates: DTStatsAggregator, nodeIndex: Int, - nodeImpurity: Double, level: Int, metadata: DecisionTreeMetadata, - splits: Array[Array[Split]]): (Split, InformationGainStats) = { + splits: Array[Array[Split]]): (Split, InformationGainStats, Predict) = { - logDebug("node impurity = " + nodeImpurity) + // calculate predict only once + var predict: Option[Predict] = None // For each (feature, split), calculate the gain, and select the best (feature, split). - Range(0, metadata.numFeatures).map { featureIndex => + val (bestSplit, bestSplitStats) = Range(0, metadata.numFeatures).map { featureIndex => val numSplits = metadata.numSplits(featureIndex) if (metadata.isContinuous(featureIndex)) { // Cumulative sum (scanLeft) of bin statistics. 
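As a usage-level illustration of the two constraints enforced above, here is a minimal sketch of training with them set via `Strategy`, following the constructor and the `DecisionTree.train` call that appear later in this patch; the `input` dataset and the chosen threshold values are assumptions:

{% highlight scala %}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.rdd.RDD

// input: RDD[LabeledPoint] is assumed to be prepared elsewhere.
def trainWithMinSplitConstraints(input: RDD[LabeledPoint]): DecisionTreeModel = {
  val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
    numClassesForClassification = 2, maxBins = 32,
    minInstancesPerNode = 2,  // splits leaving a child with fewer than 2 instances are rejected
    minInfoGain = 0.01)       // splits whose information gain is below 0.01 are rejected
  DecisionTree.train(input, strategy)
}
{% endhighlight %}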
@@ -803,8 +802,8 @@ object DecisionTree extends Serializable with Logging { val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) rightChildStats.subtract(leftChildStats) - val gainStats = - calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) + predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) + val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata) (splitIdx, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) @@ -816,8 +815,8 @@ object DecisionTree extends Serializable with Logging { Range(0, numSplits).map { splitIndex => val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) - val gainStats = - calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) + predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) + val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata) (splitIndex, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) @@ -887,8 +886,8 @@ object DecisionTree extends Serializable with Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) rightChildStats.subtract(leftChildStats) - val gainStats = - calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) + predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) + val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata) (splitIndex, gainStats) }.maxBy(_._2.gain) val categoriesForSplit = @@ -898,13 +897,17 @@ object DecisionTree extends Serializable with Logging { (bestFeatureSplit, bestFeatureGainStats) } }.maxBy(_._2.gain) + + assert(predict.isDefined, "must calculate predict for each node") + + (bestSplit, bestSplitStats, predict.get) } /** * Get the number of values to be stored per node in the bin aggregates. */ - private def getElementsPerNode(metadata: DecisionTreeMetadata): Int = { - val totalBins = metadata.numBins.sum + private def getElementsPerNode(metadata: DecisionTreeMetadata): Long = { + val totalBins = metadata.numBins.map(_.toLong).sum if (metadata.isClassification) { metadata.numClasses * totalBins } else { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 23f74d5360fe5..caaccbfb8ad16 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -49,6 +49,13 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * k) implies the feature n is categorical with k categories 0, * 1, 2, ... , k-1. It's important to note that features are * zero-indexed. + * @param minInstancesPerNode Minimum number of instances each child must have after split. + * Default value is 1. If a split cause left or right child + * to have less than minInstancesPerNode, + * this split will not be considered as a valid split. + * @param minInfoGain Minimum information gain a split must get. Default value is 0.0. 
+ * If a split has less information gain than minInfoGain, + * this split will not be considered as a valid split. * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is * 256 MB. */ @@ -61,11 +68,18 @@ class Strategy ( val maxBins: Int = 32, val quantileCalculationStrategy: QuantileStrategy = Sort, val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), + val minInstancesPerNode: Int = 1, + val minInfoGain: Double = 0.0, val maxMemoryInMB: Int = 256) extends Serializable { if (algo == Classification) { require(numClassesForClassification >= 2) } + require(minInstancesPerNode >= 1, + s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") + require(maxMemoryInMB <= 10240, + s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") + val isMulticlassClassification = algo == Classification && numClassesForClassification > 2 val isMulticlassWithCategoricalFeatures diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index 866d85a79bea1..61a94246711bf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -65,14 +65,7 @@ private[tree] class DTStatsAggregator( * Offset for each feature for calculating indices into the [[allStats]] array. */ private val featureOffsets: Array[Int] = { - def featureOffsetsCalc(total: Int, featureIndex: Int): Int = { - if (isUnordered(featureIndex)) { - total + 2 * numBins(featureIndex) - } else { - total + numBins(featureIndex) - } - } - Range(0, numFeatures).scanLeft(0)(featureOffsetsCalc).map(statsSize * _).toArray + numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins) } /** @@ -149,7 +142,7 @@ private[tree] class DTStatsAggregator( s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," + s" but was called for ordered feature $featureIndex.") val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex) - (baseOffset, baseOffset + numBins(featureIndex) * statsSize) + (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index e95add7558bcf..b6d49e5555b1a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -45,7 +45,10 @@ private[tree] class DecisionTreeMetadata( val unorderedFeatures: Set[Int], val numBins: Array[Int], val impurity: Impurity, - val quantileStrategy: QuantileStrategy) extends Serializable { + val quantileStrategy: QuantileStrategy, + val maxDepth: Int, + val minInstancesPerNode: Int, + val minInfoGain: Double) extends Serializable { def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex) @@ -127,7 +130,8 @@ private[tree] object DecisionTreeMetadata { new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins, - strategy.impurity, strategy.quantileCalculationStrategy) + strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth, + strategy.minInstancesPerNode, strategy.minInfoGain) } /** diff 
--git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 0594fd0749d21..271b2c4ad813e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -46,7 +46,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * Predict values for the given data set using the model trained. * * @param features RDD representing data points to be predicted - * @return RDD[Int] where each entry contains the corresponding prediction + * @return RDD of predictions for each of the given data points */ def predict(features: RDD[Vector]): RDD[Double] = { features.map(x => predict(x)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index fb12298e0f5d3..f3e2619bd8ba0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -26,20 +26,26 @@ import org.apache.spark.annotation.DeveloperApi * @param impurity current node impurity * @param leftImpurity left node impurity * @param rightImpurity right node impurity - * @param predict predicted value - * @param prob probability of the label (classification only) */ @DeveloperApi class InformationGainStats( val gain: Double, val impurity: Double, val leftImpurity: Double, - val rightImpurity: Double, - val predict: Double, - val prob: Double = 0.0) extends Serializable { + val rightImpurity: Double) extends Serializable { override def toString = { - "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob = %f" - .format(gain, impurity, leftImpurity, rightImpurity, predict, prob) + "gain = %f, impurity = %f, left impurity = %f, right impurity = %f" + .format(gain, impurity, leftImpurity, rightImpurity) } } + + +private[tree] object InformationGainStats { + /** + * An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to + * denote that current split doesn't satisfies minimum info gain or + * minimum number of instances per node. + */ + val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 5b8a4cbed2306..5f0095d23c7ed 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -55,6 +55,8 @@ class Node ( * build the left node and right nodes if not leaf * @param nodes array of nodes */ + @deprecated("build should no longer be used since trees are constructed on-the-fly in training", + "1.2.0") def build(nodes: Array[Node]): Unit = { logDebug("building node " + id + " at level " + Node.indexToLevel(id)) logDebug("id = " + id + ", split = " + split) @@ -93,6 +95,23 @@ class Node ( } } + /** + * Returns a deep copy of the subtree rooted at this node. 
+ */ + private[tree] def deepCopy(): Node = { + val leftNodeCopy = if (leftNode.isEmpty) { + None + } else { + Some(leftNode.get.deepCopy()) + } + val rightNodeCopy = if (rightNode.isEmpty) { + None + } else { + Some(rightNode.get.deepCopy()) + } + new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, stats) + } + /** * Get the number of nodes in tree below this node, including leaf nodes. * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2. @@ -190,4 +209,22 @@ private[tree] object Node { */ def startIndexInLevel(level: Int): Int = 1 << level + /** + * Traces down from a root node to get the node with the given node index. + * This assumes the node exists. + */ + def getNode(nodeIndex: Int, rootNode: Node): Node = { + var tmpNode: Node = rootNode + var levelsToGo = indexToLevel(nodeIndex) + while (levelsToGo > 0) { + if ((nodeIndex & (1 << levelsToGo - 1)) == 0) { + tmpNode = tmpNode.leftNode.get + } else { + tmpNode = tmpNode.rightNode.get + } + levelsToGo -= 1 + } + tmpNode + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala new file mode 100644 index 0000000000000..d8476b5cd7bc7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.tree.model + +/** + * Predicted value for a node + * @param predict predicted value + * @param prob probability of the label (classification only) + */ +private[tree] class Predict( + val predict: Double, + val prob: Double = 0.0) extends Serializable { + + override def toString = { + "predict = %f, prob = %f".format(predict, prob) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index 50fb48b40de3d..b7a85f58544a3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -19,6 +19,8 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType +import org.apache.spark.mllib.tree.configuration.FeatureType +import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType /** * :: DeveloperApi :: diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 862178694a50e..e954baaf7d91e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -43,7 +43,7 @@ object LogisticRegressionSuite { offset: Double, scale: Double, nPoints: Int, - seed: Int): Seq[LabeledPoint] = { + seed: Int): Seq[LabeledPoint] = { val rnd = new Random(seed) val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian()) @@ -58,12 +58,15 @@ object LogisticRegressionSuite { } class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Matchers { - def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { + def validatePrediction( + predictions: Seq[Double], + input: Seq[LabeledPoint], + expectedAcc: Double = 0.83) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => prediction != expected.label } // At least 83% of the predictions should be on. - ((input.length - numOffPredictions).toDouble / input.length) should be > 0.83 + ((input.length - numOffPredictions).toDouble / input.length) should be > expectedAcc } // Test if we can correctly learn A, B where Y = logistic(A + B*X) @@ -155,6 +158,41 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + test("logistic regression with initial weights and non-default regularization parameter") { + val nPoints = 10000 + val A = 2.0 + val B = -1.5 + + val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) + + val initialB = -1.0 + val initialWeights = Vectors.dense(initialB) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + // Use half as many iterations as the previous test. + val lr = new LogisticRegressionWithSGD().setIntercept(true) + lr.optimizer. + setStepSize(10.0). + setNumIterations(10). 
+ setRegParam(1.0) + + val model = lr.run(testRDD, initialWeights) + + // Test the weights + assert(model.weights(0) ~== -430000.0 relTol 20000.0) + assert(model.intercept ~== 370000.0 relTol 20000.0) + + val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData, 0.8) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData, 0.8) + } + test("logistic regression with initial weights with LBFGS") { val nPoints = 10000 val A = 2.0 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 69482f2acbb40..2b2e579b992f6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} -import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node} +import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node} import org.apache.spark.mllib.util.LocalSparkContext class DecisionTreeSuite extends FunSuite with LocalSparkContext { @@ -270,18 +270,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 0) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, - new Array[Node](0), splits, bins, 10) + val (rootNode: Node, doneTraining: Boolean) = + DecisionTree.findBestSplits(treeInput, metadata, 0, null, splits, bins, 10) - val split = bestSplits(0)._1 + val split = rootNode.split.get assert(split.categories === List(1.0)) assert(split.featureType === Categorical) assert(split.threshold === Double.MinValue) - val stats = bestSplits(0)._2 + val stats = rootNode.stats.get assert(stats.gain > 0) - assert(stats.predict === 1) - assert(stats.prob === 0.6) + assert(rootNode.predict === 1) assert(stats.impurity > 0.2) } @@ -302,18 +301,18 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, - new Array[Node](0), splits, bins, 10) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) - val split = bestSplits(0)._1 + val split = rootNode.split.get assert(split.categories.length === 1) assert(split.categories.contains(1.0)) assert(split.featureType === Categorical) assert(split.threshold === Double.MinValue) - val stats = bestSplits(0)._2 + val stats = rootNode.stats.get assert(stats.gain > 0) - assert(stats.predict === 0.6) + assert(rootNode.predict === 0.6) assert(stats.impurity > 0.2) } @@ -354,13 +353,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = 
DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, - new Array[Node](0), splits, bins, 10) - assert(bestSplits.length === 1) - assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._2.gain === 0) - assert(bestSplits(0)._2.leftImpurity === 0) - assert(bestSplits(0)._2.rightImpurity === 0) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) + + val split = rootNode.split.get + assert(split.feature === 0) + + val stats = rootNode.stats.get + assert(stats.gain === 0) + assert(stats.leftImpurity === 0) + assert(stats.rightImpurity === 0) } test("Binary classification stump with fixed label 1 for Gini") { @@ -380,14 +382,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0, - new Array[Node](0), splits, bins, 10) - assert(bestSplits.length === 1) - assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._2.gain === 0) - assert(bestSplits(0)._2.leftImpurity === 0) - assert(bestSplits(0)._2.rightImpurity === 0) - assert(bestSplits(0)._2.predict === 1) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) + + val split = rootNode.split.get + assert(split.feature === 0) + + val stats = rootNode.stats.get + assert(stats.gain === 0) + assert(stats.leftImpurity === 0) + assert(stats.rightImpurity === 0) + assert(rootNode.predict === 1) } test("Binary classification stump with fixed label 0 for Entropy") { @@ -407,14 +412,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0, - new Array[Node](0), splits, bins, 10) - assert(bestSplits.length === 1) - assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._2.gain === 0) - assert(bestSplits(0)._2.leftImpurity === 0) - assert(bestSplits(0)._2.rightImpurity === 0) - assert(bestSplits(0)._2.predict === 0) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) + + val split = rootNode.split.get + assert(split.feature === 0) + + val stats = rootNode.stats.get + assert(stats.gain === 0) + assert(stats.leftImpurity === 0) + assert(stats.rightImpurity === 0) + assert(rootNode.predict === 0) } test("Binary classification stump with fixed label 1 for Entropy") { @@ -434,14 +442,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0, - new Array[Node](0), splits, bins, 10) - assert(bestSplits.length === 1) - assert(bestSplits(0)._1.feature === 0) - assert(bestSplits(0)._2.gain === 0) - assert(bestSplits(0)._2.leftImpurity === 0) - assert(bestSplits(0)._2.rightImpurity === 0) - assert(bestSplits(0)._2.predict === 1) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) + + val split = rootNode.split.get + assert(split.feature === 0) + + val stats = rootNode.stats.get + assert(stats.gain === 0) + assert(stats.leftImpurity === 0) + assert(stats.rightImpurity === 0) + assert(rootNode.predict === 1) } test("Second level node 
building with vs. without groups") { @@ -457,40 +468,46 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) // Train a 1-node model - val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100) + val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1, + numClassesForClassification = 2, maxBins = 100) val modelOneNode = DecisionTree.train(rdd, strategyOneNode) - val nodes: Array[Node] = new Array[Node](8) - nodes(1) = modelOneNode.topNode - nodes(1).leftNode = None - nodes(1).rightNode = None - - val parentImpurities = Array(0, 0.5, 0.5, 0.5) + val rootNodeCopy1 = modelOneNode.topNode.deepCopy() + val rootNodeCopy2 = modelOneNode.topNode.deepCopy() // Single group second level tree construction. val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, nodes, - splits, bins, 10) - assert(bestSplits.length === 2) - assert(bestSplits(0)._2.gain > 0) - assert(bestSplits(1)._2.gain > 0) + val (rootNode, _) = DecisionTree.findBestSplits(treeInput, metadata, 1, + rootNodeCopy1, splits, bins, 10) + assert(rootNode.leftNode.nonEmpty) + assert(rootNode.rightNode.nonEmpty) + val children1 = new Array[Node](2) + children1(0) = rootNode.leftNode.get + children1(1) = rootNode.rightNode.get // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second // level tree construction. - val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, - nodes, splits, bins, 0) - assert(bestSplitsWithGroups.length === 2) - assert(bestSplitsWithGroups(0)._2.gain > 0) - assert(bestSplitsWithGroups(1)._2.gain > 0) + val (rootNode2, _) = DecisionTree.findBestSplits(treeInput, metadata, 1, + rootNodeCopy2, splits, bins, 0) + assert(rootNode2.leftNode.nonEmpty) + assert(rootNode2.rightNode.nonEmpty) + val children2 = new Array[Node](2) + children2(0) = rootNode2.leftNode.get + children2(1) = rootNode2.rightNode.get // Verify whether the splits obtained using single group and multiple group level // construction strategies are the same. 
- for (i <- 0 until bestSplits.length) { - assert(bestSplits(i)._1 === bestSplitsWithGroups(i)._1) - assert(bestSplits(i)._2.gain === bestSplitsWithGroups(i)._2.gain) - assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity) - assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity) - assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity) - assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict) + for (i <- 0 until 2) { + assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0) + assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0) + assert(children1(i).split === children2(i).split) + assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty) + val stats1 = children1(i).stats.get + val stats2 = children2(i).stats.get + assert(stats1.gain === stats2.gain) + assert(stats1.impurity === stats2.impurity) + assert(stats1.leftImpurity === stats2.leftImpurity) + assert(stats1.rightImpurity === stats2.rightImpurity) + assert(children1(i).predict === children2(i).predict) } } @@ -506,15 +523,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, - new Array[Node](0), splits, bins, 10) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) - assert(bestSplits.length === 1) - val bestSplit = bestSplits(0)._1 - assert(bestSplit.feature === 0) - assert(bestSplit.categories.length === 1) - assert(bestSplit.categories.contains(1)) - assert(bestSplit.featureType === Categorical) + val split = rootNode.split.get + assert(split.feature === 0) + assert(split.categories.length === 1) + assert(split.categories.contains(1)) + assert(split.featureType === Categorical) } test("Binary classification stump with 1 continuous feature, to check off-by-1 error") { @@ -571,16 +587,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, - new Array[Node](0), splits, bins, 10) - - assert(bestSplits.length === 1) - val bestSplit = bestSplits(0)._1 - assert(bestSplit.feature === 0) - assert(bestSplit.categories.length === 1) - assert(bestSplit.categories.contains(1)) - assert(bestSplit.featureType === Categorical) - val gain = bestSplits(0)._2 + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) + + val split = rootNode.split.get + assert(split.feature === 0) + assert(split.categories.length === 1) + assert(split.categories.contains(1)) + assert(split.featureType === Categorical) + + val gain = rootNode.stats.get assert(gain.leftImpurity === 0) assert(gain.rightImpurity === 0) } @@ -598,16 +614,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, - new Array[Node](0), splits, bins, 10) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) - 
assert(bestSplits.length === 1) - val bestSplit = bestSplits(0)._1 - - assert(bestSplit.feature === 1) - assert(bestSplit.featureType === Continuous) - assert(bestSplit.threshold > 1980) - assert(bestSplit.threshold < 2020) + val split = rootNode.split.get + assert(split.feature === 1) + assert(split.featureType === Continuous) + assert(split.threshold > 1980) + assert(split.threshold < 2020) } @@ -625,16 +639,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, - new Array[Node](0), splits, bins, 10) - - assert(bestSplits.length === 1) - val bestSplit = bestSplits(0)._1 + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) - assert(bestSplit.feature === 1) - assert(bestSplit.featureType === Continuous) - assert(bestSplit.threshold > 1980) - assert(bestSplit.threshold < 2020) + val split = rootNode.split.get + assert(split.feature === 1) + assert(split.featureType === Continuous) + assert(split.threshold > 1980) + assert(split.threshold < 2020) } test("Multiclass classification stump with 10-ary (ordered) categorical features") { @@ -650,15 +662,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, - new Array[Node](0), splits, bins, 10) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) - assert(bestSplits.length === 1) - val bestSplit = bestSplits(0)._1 - assert(bestSplit.feature === 0) - assert(bestSplit.categories.length === 1) - assert(bestSplit.categories.contains(1.0)) - assert(bestSplit.featureType === Categorical) + val split = rootNode.split.get + assert(split.feature === 0) + assert(split.categories.length === 1) + assert(split.categories.contains(1.0)) + assert(split.featureType === Categorical) } test("Multiclass classification tree with 10-ary (ordered) categorical features," + @@ -674,6 +685,88 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { validateClassifier(model, arr, 0.6) } + test("split must satisfy min instances per node requirements") { + val arr = new Array[LabeledPoint](3) + arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) + arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) + arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))) + + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, + maxDepth = 2, numClassesForClassification = 2, minInstancesPerNode = 2) + + val model = DecisionTree.train(input, strategy) + assert(model.topNode.isLeaf) + assert(model.topNode.predict == 0.0) + val predicts = input.map(p => model.predict(p.features)).collect() + predicts.foreach { predict => + assert(predict == 0.0) + } + + // test for findBestSplits when no valid split can be found + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) + + val 
gain = rootNode.stats.get + assert(gain == InformationGainStats.invalidInformationGainStats) + } + + test("do not choose split that does not satisfy min instance per node requirements") { + // if a split does not satisfy min instances per node requirements, + // this split is invalid, even though the information gain of split is large. + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0, 1.0)) + arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)) + arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)) + + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, + maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2), + numClassesForClassification = 2, minInstancesPerNode = 2) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) + + val split = rootNode.split.get + val gain = rootNode.stats.get + assert(split.feature == 1) + assert(gain != InformationGainStats.invalidInformationGainStats) + } + + test("split must satisfy min info gain requirements") { + val arr = new Array[LabeledPoint](3) + arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) + arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) + arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))) + + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClassesForClassification = 2, minInfoGain = 1.0) + + val model = DecisionTree.train(input, strategy) + assert(model.topNode.isLeaf) + assert(model.topNode.predict == 0.0) + val predicts = input.map(p => model.predict(p.features)).collect() + predicts.foreach { predict => + assert(predict == 0.0) + } + + // test for findBestSplits when no valid split can be found + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, + null, splits, bins, 10) + + val gain = rootNode.stats.get + assert(gain == InformationGainStats.invalidInformationGainStats) + } } object DecisionTreeSuite { @@ -699,13 +792,16 @@ object DecisionTreeSuite { def generateOrderedLabeledPoints(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000) { - if (i < 600) { - val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) - arr(i) = lp + val label = if (i < 100) { + 0.0 + } else if (i < 500) { + 1.0 + } else if (i < 900) { + 0.0 } else { - val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i)) - arr(i) = lp + 1.0 } + arr(i) = new LabeledPoint(label, Vectors.dense(i.toDouble, 1000.0 - i)) } arr } diff --git a/pom.xml b/pom.xml index 64fb1e57e30e0..520aed3806937 100644 --- a/pom.xml +++ b/pom.xml @@ -134,6 +134,7 @@ 0.3.6 3.0.0 1.7.6 + 0.7.1 1.8.3 1.1.0 @@ -621,6 +622,7 @@ org.apache.avro avro-mapred ${avro.version} + ${avro.mapred.classifier} io.netty @@ -839,7 +841,6 @@ -unchecked -deprecation -feature - -language:postfixOps -Xms1024m @@ -899,7 +900,7 @@ true ${session.executionRootDirectory} 1 - 0 + false @@ 
-1109,6 +1110,7 @@ 2.2.0 2.5.0 + hadoop2 @@ -1118,6 +1120,7 @@ 2.3.0 2.5.0 0.9.0 + hadoop2 @@ -1127,6 +1130,7 @@ 2.4.0 2.5.0 0.9.0 + hadoop2 diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 0f5d71afcf616..39f8ba4745737 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -30,6 +30,12 @@ object MimaBuild { def excludeMember(fullName: String) = Seq( ProblemFilters.exclude[MissingMethodProblem](fullName), + // Sometimes excluded methods have default arguments and + // they are translated into public methods/fields($default$) in generated + // bytecode. It is not possible to exhaustively list everything. + // But this should be okay. + ProblemFilters.exclude[MissingMethodProblem](fullName+"$default$2"), + ProblemFilters.exclude[MissingMethodProblem](fullName+"$default$1"), ProblemFilters.exclude[MissingFieldProblem](fullName), ProblemFilters.exclude[IncompatibleResultTypeProblem](fullName), ProblemFilters.exclude[IncompatibleMethTypeProblem](fullName), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 46b78bd5c7061..2f1e05dfcc7b1 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -37,14 +37,8 @@ object MimaExcludes { Seq( MimaBuild.excludeSparkPackage("deploy"), MimaBuild.excludeSparkPackage("graphx") - ) ++ - // This is @DeveloperAPI, but Mima still gives false-positives: - MimaBuild.excludeSparkClass("scheduler.SparkListenerApplicationStart") ++ - Seq( - // This is @Experimental, but Mima still gives false-positives: - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.foreachAsync") ) + case v if v.startsWith("1.1") => Seq( MimaBuild.excludeSparkPackage("deploy"), diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 45f6d2973ea90..ab9f8ba120e83 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -187,7 +187,7 @@ object OldDeps { Some("org.apache.spark" % fullId % "1.1.0") } - def oldDepsSettings() = Defaults.defaultSettings ++ Seq( + def oldDepsSettings() = Defaults.coreDefaultSettings ++ Seq( name := "old-deps", scalaVersion := "2.10.4", retrieveManaged := true, @@ -337,7 +337,7 @@ object TestSettings { javaOptions in Test += "-Dspark.test.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dspark.ports.maxRetries=100", - javaOptions in Test += "-Dspark.ui.port=0", + javaOptions in Test += "-Dspark.ui.enabled=false", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") .map { case (k,v) => s"-D$k=$v" }.toSeq, diff --git a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala b/project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala new file mode 100644 index 0000000000000..3d43c35299555 --- /dev/null +++ b/project/spark-style/src/main/scala/org/apache/spark/scalastyle/NonASCIICharacterChecker.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.scalastyle + +import java.util.regex.Pattern + +import org.scalastyle.{PositionError, ScalariformChecker, ScalastyleError} + +import scalariform.lexer.Token +import scalariform.parser.CompilationUnit + +class NonASCIICharacterChecker extends ScalariformChecker { + val errorKey: String = "non.ascii.character.disallowed" + + override def verify(ast: CompilationUnit): List[ScalastyleError] = { + ast.tokens.filter(hasNonAsciiChars).map(x => PositionError(x.offset)).toList + } + + private def hasNonAsciiChars(x: Token) = + x.rawText.trim.nonEmpty && !Pattern.compile( """\p{ASCII}+""", Pattern.DOTALL) + .matcher(x.text.trim).matches() + +} diff --git a/python/docs/Makefile b/python/docs/Makefile new file mode 100644 index 0000000000000..8a1324eecd325 --- /dev/null +++ b/python/docs/Makefile @@ -0,0 +1,179 @@ +# Makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +PAPER = +BUILDDIR = _build + +export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.8.2.1-src.zip) + +# User-friendly check for sphinx-build +ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) +$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) +endif + +# Internal variables. +PAPEROPT_a4 = -D latex_paper_size=a4 +PAPEROPT_letter = -D latex_paper_size=letter +ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . +# the i18n builder cannot share the environment and doctrees with the others +I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 
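The PYTHONPATH export above is what allows Sphinx's autodoc to import pyspark straight out of the source tree without an installation. Roughly, it corresponds to the following setup, shown here only as a sketch (the py4j zip name is assumed to match the one bundled under python/lib):

    import os
    import sys

    # Make the pyspark package and the bundled py4j zip importable for autodoc.
    docs_dir = os.path.dirname(os.path.abspath(__file__))
    sys.path.insert(0, os.path.join(docs_dir, ".."))
    sys.path.insert(0, os.path.join(docs_dir, "..", "lib", "py4j-0.8.2.1-src.zip"))

    import pyspark  # should now resolve without a site-packages install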
+ +.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext + +help: + @echo "Please use \`make ' where is one of" + @echo " html to make standalone HTML files" + @echo " dirhtml to make HTML files named index.html in directories" + @echo " singlehtml to make a single large HTML file" + @echo " pickle to make pickle files" + @echo " json to make JSON files" + @echo " htmlhelp to make HTML files and a HTML help project" + @echo " qthelp to make HTML files and a qthelp project" + @echo " devhelp to make HTML files and a Devhelp project" + @echo " epub to make an epub" + @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" + @echo " latexpdf to make LaTeX files and run them through pdflatex" + @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" + @echo " text to make text files" + @echo " man to make manual pages" + @echo " texinfo to make Texinfo files" + @echo " info to make Texinfo files and run them through makeinfo" + @echo " gettext to make PO message catalogs" + @echo " changes to make an overview of all changed/added/deprecated items" + @echo " xml to make Docutils-native XML files" + @echo " pseudoxml to make pseudoxml-XML files for display purposes" + @echo " linkcheck to check all external links for integrity" + @echo " doctest to run all doctests embedded in the documentation (if enabled)" + +clean: + rm -rf $(BUILDDIR)/* + +html: + $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." + +dirhtml: + $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." + +singlehtml: + $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml + @echo + @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." + +pickle: + $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle + @echo + @echo "Build finished; now you can process the pickle files." + +json: + $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json + @echo + @echo "Build finished; now you can process the JSON files." + +htmlhelp: + $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp + @echo + @echo "Build finished; now you can run HTML Help Workshop with the" \ + ".hhp project file in $(BUILDDIR)/htmlhelp." + +qthelp: + $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp + @echo + @echo "Build finished; now you can run "qcollectiongenerator" with the" \ + ".qhcp project file in $(BUILDDIR)/qthelp, like this:" + @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/pyspark.qhcp" + @echo "To view the help file:" + @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/pyspark.qhc" + +devhelp: + $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp + @echo + @echo "Build finished." + @echo "To view the help file:" + @echo "# mkdir -p $$HOME/.local/share/devhelp/pyspark" + @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/pyspark" + @echo "# devhelp" + +epub: + $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub + @echo + @echo "Build finished. The epub file is in $(BUILDDIR)/epub." + +latex: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo + @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." + @echo "Run \`make' in that directory to run these through (pdf)latex" \ + "(use \`make latexpdf' here to do that automatically)." 
+ +latexpdf: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo "Running LaTeX files through pdflatex..." + $(MAKE) -C $(BUILDDIR)/latex all-pdf + @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." + +latexpdfja: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo "Running LaTeX files through platex and dvipdfmx..." + $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja + @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." + +text: + $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text + @echo + @echo "Build finished. The text files are in $(BUILDDIR)/text." + +man: + $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man + @echo + @echo "Build finished. The manual pages are in $(BUILDDIR)/man." + +texinfo: + $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo + @echo + @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." + @echo "Run \`make' in that directory to run these through makeinfo" \ + "(use \`make info' here to do that automatically)." + +info: + $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo + @echo "Running Texinfo files through makeinfo..." + make -C $(BUILDDIR)/texinfo info + @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." + +gettext: + $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale + @echo + @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." + +changes: + $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes + @echo + @echo "The overview file is in $(BUILDDIR)/changes." + +linkcheck: + $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck + @echo + @echo "Link check complete; look for any errors in the above output " \ + "or in $(BUILDDIR)/linkcheck/output.txt." + +doctest: + $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest + @echo "Testing of doctests in the sources finished, look at the " \ + "results in $(BUILDDIR)/doctest/output.txt." + +xml: + $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml + @echo + @echo "Build finished. The XML files are in $(BUILDDIR)/xml." + +pseudoxml: + $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml + @echo + @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." diff --git a/python/docs/conf.py b/python/docs/conf.py new file mode 100644 index 0000000000000..c368cf81a003b --- /dev/null +++ b/python/docs/conf.py @@ -0,0 +1,332 @@ +# -*- coding: utf-8 -*- +# +# pyspark documentation build configuration file, created by +# sphinx-quickstart on Thu Aug 28 15:17:47 2014. +# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +import sys +import os + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +sys.path.insert(0, os.path.abspath('.')) + +# -- General configuration ------------------------------------------------ + +# If your documentation needs a minimal Sphinx version, state it here. +#needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. 
+extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.viewcode', + 'epytext', +] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix of source filenames. +source_suffix = '.rst' + +# The encoding of source files. +#source_encoding = 'utf-8-sig' + +# The master toctree document. +master_doc = 'index' + +# General information about the project. +project = u'PySpark' +copyright = u'2014, Author' + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +version = '1.1' +# The full version, including alpha/beta/rc tags. +release = '' + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +#language = None + +# There are two options for replacing |today|: either, you set today to some +# non-false value, then it is used: +#today = '' +# Else, today_fmt is used as the format for a strftime call. +#today_fmt = '%B %d, %Y' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +exclude_patterns = ['_build'] + +# The reST default role (used for this markup: `text`) to use for all +# documents. +#default_role = None + +# If true, '()' will be appended to :func: etc. cross-reference text. +#add_function_parentheses = True + +# If true, the current module name will be prepended to all description +# unit titles (such as .. function::). +#add_module_names = True + +# If true, sectionauthor and moduleauthor directives will be shown in the +# output. They are ignored by default. +#show_authors = False + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + +# A list of ignored prefixes for module index sorting. +#modindex_common_prefix = [] + +# If true, keep warnings as "system message" paragraphs in the built documents. +#keep_warnings = False + + +# -- Options for HTML output ---------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +html_theme = 'default' + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +#html_theme_options = {} + +# Add any paths that contain custom themes here, relative to this directory. +#html_theme_path = [] + +# The name for this set of Sphinx documents. If None, it defaults to +# " v documentation". +#html_title = None + +# A shorter title for the navigation bar. Default is the same as html_title. +#html_short_title = None + +# The name of an image file (relative to this directory) to place at the top +# of the sidebar. +#html_logo = None + +# The name of an image file (within the static path) to use as favicon of the +# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 +# pixels large. +#html_favicon = None + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +# Add any extra paths that contain custom files (such as robots.txt or +# .htaccess) here, relative to this directory. 
These files are copied +# directly to the root of the documentation. +#html_extra_path = [] + +# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, +# using the given strftime format. +#html_last_updated_fmt = '%b %d, %Y' + +# If true, SmartyPants will be used to convert quotes and dashes to +# typographically correct entities. +#html_use_smartypants = True + +# Custom sidebar templates, maps document names to template names. +#html_sidebars = {} + +# Additional templates that should be rendered to pages, maps page names to +# template names. +#html_additional_pages = {} + +# If false, no module index is generated. +#html_domain_indices = True + +# If false, no index is generated. +#html_use_index = True + +# If true, the index is split into individual pages for each letter. +#html_split_index = False + +# If true, links to the reST sources are added to the pages. +#html_show_sourcelink = True + +# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. +#html_show_sphinx = True + +# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. +#html_show_copyright = True + +# If true, an OpenSearch description file will be output, and all pages will +# contain a tag referring to it. The value of this option must be the +# base URL from which the finished HTML is served. +#html_use_opensearch = '' + +# This is the file name suffix for HTML files (e.g. ".xhtml"). +#html_file_suffix = None + +# Output file base name for HTML help builder. +htmlhelp_basename = 'pysparkdoc' + + +# -- Options for LaTeX output --------------------------------------------- + +latex_elements = { +# The paper size ('letterpaper' or 'a4paper'). +#'papersize': 'letterpaper', + +# The font size ('10pt', '11pt' or '12pt'). +#'pointsize': '10pt', + +# Additional stuff for the LaTeX preamble. +#'preamble': '', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + ('index', 'pyspark.tex', u'pyspark Documentation', + u'Author', 'manual'), +] + +# The name of an image file (relative to this directory) to place at the top of +# the title page. +#latex_logo = None + +# For "manual" documents, if this is true, then toplevel headings are parts, +# not chapters. +#latex_use_parts = False + +# If true, show page references after internal links. +#latex_show_pagerefs = False + +# If true, show URL addresses after external links. +#latex_show_urls = False + +# Documents to append as an appendix to all manuals. +#latex_appendices = [] + +# If false, no module index is generated. +#latex_domain_indices = True + + +# -- Options for manual page output --------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + ('index', 'pyspark', u'pyspark Documentation', + [u'Author'], 1) +] + +# If true, show URL addresses after external links. +#man_show_urls = False + + +# -- Options for Texinfo output ------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + ('index', 'pyspark', u'pyspark Documentation', + u'Author', 'pyspark', 'One line description of project.', + 'Miscellaneous'), +] + +# Documents to append as an appendix to all manuals. 
+#texinfo_appendices = [] + +# If false, no module index is generated. +#texinfo_domain_indices = True + +# How to display URL addresses: 'footnote', 'no', or 'inline'. +#texinfo_show_urls = 'footnote' + +# If true, do not generate a @detailmenu in the "Top" node's menu. +#texinfo_no_detailmenu = False + + +# -- Options for Epub output ---------------------------------------------- + +# Bibliographic Dublin Core info. +epub_title = u'pyspark' +epub_author = u'Author' +epub_publisher = u'Author' +epub_copyright = u'2014, Author' + +# The basename for the epub file. It defaults to the project name. +#epub_basename = u'pyspark' + +# The HTML theme for the epub output. Since the default themes are not optimized +# for small screen space, using the same theme for HTML and epub output is +# usually not wise. This defaults to 'epub', a theme designed to save visual +# space. +#epub_theme = 'epub' + +# The language of the text. It defaults to the language option +# or en if the language is not set. +#epub_language = '' + +# The scheme of the identifier. Typical schemes are ISBN or URL. +#epub_scheme = '' + +# The unique identifier of the text. This can be a ISBN number +# or the project homepage. +#epub_identifier = '' + +# A unique identification for the text. +#epub_uid = '' + +# A tuple containing the cover image and cover page html template filenames. +#epub_cover = () + +# A sequence of (type, uri, title) tuples for the guide element of content.opf. +#epub_guide = () + +# HTML files that should be inserted before the pages created by sphinx. +# The format is a list of tuples containing the path and title. +#epub_pre_files = [] + +# HTML files shat should be inserted after the pages created by sphinx. +# The format is a list of tuples containing the path and title. +#epub_post_files = [] + +# A list of files that should not be packed into the epub file. +epub_exclude_files = ['search.html'] + +# The depth of the table of contents in toc.ncx. +#epub_tocdepth = 3 + +# Allow duplicate toc entries. +#epub_tocdup = True + +# Choose between 'default' and 'includehidden'. +#epub_tocscope = 'default' + +# Fix unsupported image types using the PIL. +#epub_fix_images = False + +# Scale large images. +#epub_max_image_width = 0 + +# How to display URL addresses: 'footnote', 'no', or 'inline'. +#epub_show_urls = 'inline' + +# If false, no index is generated. +#epub_use_index = True diff --git a/python/docs/epytext.py b/python/docs/epytext.py new file mode 100644 index 0000000000000..61d731bff570d --- /dev/null +++ b/python/docs/epytext.py @@ -0,0 +1,27 @@ +import re + +RULES = ( + (r"<[\w.]+>", r""), + (r"L{([\w.()]+)}", r":class:`\1`"), + (r"[LC]{(\w+\.\w+)\(\)}", r":func:`\1`"), + (r"C{([\w.()]+)}", r":class:`\1`"), + (r"[IBCM]{(.+)}", r"`\1`"), + ('pyspark.rdd.RDD', 'RDD'), +) + +def _convert_epytext(line): + """ + >>> _convert_epytext("L{A}") + :class:`A` + """ + line = line.replace('@', ':') + for p, sub in RULES: + line = re.sub(p, sub, line) + return line + +def _process_docstring(app, what, name, obj, options, lines): + for i in range(len(lines)): + lines[i] = _convert_epytext(lines[i]) + +def setup(app): + app.connect("autodoc-process-docstring", _process_docstring) diff --git a/python/docs/index.rst b/python/docs/index.rst new file mode 100644 index 0000000000000..25b3f9bd93e63 --- /dev/null +++ b/python/docs/index.rst @@ -0,0 +1,37 @@ +.. pyspark documentation master file, created by + sphinx-quickstart on Thu Aug 28 15:17:47 2014. 
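To make the autodoc hook above concrete, here is a standalone paraphrase of two of the epytext.py rules, showing how an epydoc-style docstring line is rewritten into reST before rendering (simplified; the module above carries the full rule set):

    import re

    # Simplified versions of two substitution rules from epytext.py.
    RULES = (
        (r"L\{([\w.()]+)\}", r":class:`\1`"),  # L{SparkContext} -> :class:`SparkContext`
        (r"C\{([\w.()]+)\}", r":class:`\1`"),  # C{list}         -> :class:`list`
    )

    def convert(line):
        line = line.replace("@", ":")          # @param / @return -> :param / :return
        for pattern, repl in RULES:
            line = re.sub(pattern, repl, line)
        return line

    assert convert("L{SparkContext}") == ":class:`SparkContext`"
    assert convert("@param name: the application name") == ":param name: the application name"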
+ You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to PySpark API reference! +=================================== + +Contents: + +.. toctree:: + :maxdepth: 2 + + pyspark + pyspark.sql + pyspark.mllib + + +Core classes: +--------------- + + :class:`pyspark.SparkContext` + + Main entry point for Spark functionality. + + :class:`pyspark.RDD` + + A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` + diff --git a/python/docs/make.bat b/python/docs/make.bat new file mode 100644 index 0000000000000..adad44fd7536a --- /dev/null +++ b/python/docs/make.bat @@ -0,0 +1,242 @@ +@ECHO OFF + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set BUILDDIR=_build +set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . +set I18NSPHINXOPTS=%SPHINXOPTS% . +if NOT "%PAPER%" == "" ( + set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% + set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% +) + +if "%1" == "" goto help + +if "%1" == "help" ( + :help + echo.Please use `make ^` where ^ is one of + echo. html to make standalone HTML files + echo. dirhtml to make HTML files named index.html in directories + echo. singlehtml to make a single large HTML file + echo. pickle to make pickle files + echo. json to make JSON files + echo. htmlhelp to make HTML files and a HTML help project + echo. qthelp to make HTML files and a qthelp project + echo. devhelp to make HTML files and a Devhelp project + echo. epub to make an epub + echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter + echo. text to make text files + echo. man to make manual pages + echo. texinfo to make Texinfo files + echo. gettext to make PO message catalogs + echo. changes to make an overview over all changed/added/deprecated items + echo. xml to make Docutils-native XML files + echo. pseudoxml to make pseudoxml-XML files for display purposes + echo. linkcheck to check all external links for integrity + echo. doctest to run all doctests embedded in the documentation if enabled + goto end +) + +if "%1" == "clean" ( + for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i + del /q /s %BUILDDIR%\* + goto end +) + + +%SPHINXBUILD% 2> nul +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "html" ( + %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/html. + goto end +) + +if "%1" == "dirhtml" ( + %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. + goto end +) + +if "%1" == "singlehtml" ( + %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. + goto end +) + +if "%1" == "pickle" ( + %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle + if errorlevel 1 exit /b 1 + echo. 
+ echo.Build finished; now you can process the pickle files. + goto end +) + +if "%1" == "json" ( + %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the JSON files. + goto end +) + +if "%1" == "htmlhelp" ( + %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run HTML Help Workshop with the ^ +.hhp project file in %BUILDDIR%/htmlhelp. + goto end +) + +if "%1" == "qthelp" ( + %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run "qcollectiongenerator" with the ^ +.qhcp project file in %BUILDDIR%/qthelp, like this: + echo.^> qcollectiongenerator %BUILDDIR%\qthelp\pyspark.qhcp + echo.To view the help file: + echo.^> assistant -collectionFile %BUILDDIR%\qthelp\pyspark.ghc + goto end +) + +if "%1" == "devhelp" ( + %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. + goto end +) + +if "%1" == "epub" ( + %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The epub file is in %BUILDDIR%/epub. + goto end +) + +if "%1" == "latex" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "latexpdf" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + cd %BUILDDIR%/latex + make all-pdf + cd %BUILDDIR%/.. + echo. + echo.Build finished; the PDF files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "latexpdfja" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + cd %BUILDDIR%/latex + make all-pdf-ja + cd %BUILDDIR%/.. + echo. + echo.Build finished; the PDF files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "text" ( + %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The text files are in %BUILDDIR%/text. + goto end +) + +if "%1" == "man" ( + %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The manual pages are in %BUILDDIR%/man. + goto end +) + +if "%1" == "texinfo" ( + %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. + goto end +) + +if "%1" == "gettext" ( + %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The message catalogs are in %BUILDDIR%/locale. + goto end +) + +if "%1" == "changes" ( + %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes + if errorlevel 1 exit /b 1 + echo. + echo.The overview file is in %BUILDDIR%/changes. + goto end +) + +if "%1" == "linkcheck" ( + %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck + if errorlevel 1 exit /b 1 + echo. + echo.Link check complete; look for any errors in the above output ^ +or in %BUILDDIR%/linkcheck/output.txt. + goto end +) + +if "%1" == "doctest" ( + %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest + if errorlevel 1 exit /b 1 + echo. + echo.Testing of doctests in the sources finished, look at the ^ +results in %BUILDDIR%/doctest/output.txt. + goto end +) + +if "%1" == "xml" ( + %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. 
The XML files are in %BUILDDIR%/xml. + goto end +) + +if "%1" == "pseudoxml" ( + %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. + goto end +) + +:end diff --git a/python/docs/modules.rst b/python/docs/modules.rst new file mode 100644 index 0000000000000..183564659fbcf --- /dev/null +++ b/python/docs/modules.rst @@ -0,0 +1,7 @@ +. += + +.. toctree:: + :maxdepth: 4 + + pyspark diff --git a/python/docs/pyspark.mllib.rst b/python/docs/pyspark.mllib.rst new file mode 100644 index 0000000000000..e95d19e97f151 --- /dev/null +++ b/python/docs/pyspark.mllib.rst @@ -0,0 +1,77 @@ +pyspark.mllib package +===================== + +Submodules +---------- + +pyspark.mllib.classification module +----------------------------------- + +.. automodule:: pyspark.mllib.classification + :members: + :undoc-members: + :show-inheritance: + +pyspark.mllib.clustering module +------------------------------- + +.. automodule:: pyspark.mllib.clustering + :members: + :undoc-members: + :show-inheritance: + +pyspark.mllib.linalg module +--------------------------- + +.. automodule:: pyspark.mllib.linalg + :members: + :undoc-members: + :show-inheritance: + +pyspark.mllib.random module +--------------------------- + +.. automodule:: pyspark.mllib.random + :members: + :undoc-members: + :show-inheritance: + +pyspark.mllib.recommendation module +----------------------------------- + +.. automodule:: pyspark.mllib.recommendation + :members: + :undoc-members: + :show-inheritance: + +pyspark.mllib.regression module +------------------------------- + +.. automodule:: pyspark.mllib.regression + :members: + :undoc-members: + :show-inheritance: + +pyspark.mllib.stat module +------------------------- + +.. automodule:: pyspark.mllib.stat + :members: + :undoc-members: + :show-inheritance: + +pyspark.mllib.tree module +------------------------- + +.. automodule:: pyspark.mllib.tree + :members: + :undoc-members: + :show-inheritance: + +pyspark.mllib.util module +------------------------- + +.. automodule:: pyspark.mllib.util + :members: + :undoc-members: + :show-inheritance: diff --git a/python/docs/pyspark.rst b/python/docs/pyspark.rst new file mode 100644 index 0000000000000..a68bd62433085 --- /dev/null +++ b/python/docs/pyspark.rst @@ -0,0 +1,18 @@ +pyspark package +=============== + +Subpackages +----------- + +.. toctree:: + :maxdepth: 1 + + pyspark.mllib + pyspark.sql + +Contents +-------- + +.. automodule:: pyspark + :members: + :undoc-members: diff --git a/python/docs/pyspark.sql.rst b/python/docs/pyspark.sql.rst new file mode 100644 index 0000000000000..65b3650ae10ab --- /dev/null +++ b/python/docs/pyspark.sql.rst @@ -0,0 +1,10 @@ +pyspark.sql module +================== + +Module contents +--------------- + +.. automodule:: pyspark.sql + :members: + :undoc-members: + :show-inheritance: diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 5c7c9cc161dff..f124dc6c07575 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -78,6 +78,9 @@ def value(self): return self._value def unpersist(self, blocking=False): + """ + Delete cached copies of this broadcast on the executors. 
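A short usage sketch for the unpersist call documented above, assuming an active SparkContext `sc` (illustrative only):

    # Ship a small lookup table to the executors.
    lookup = sc.broadcast({"a": 1, "b": 2})
    total = sc.parallelize(["a", "b", "a"]).map(lambda k: lookup.value[k]).sum()
    assert total == 4

    # Release the cached copies on the executors once the broadcast is no longer needed.
    lookup.unpersist()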
+ """ self._jbroadcast.unpersist(blocking) os.unlink(self.path) diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 80e51d1a583a0..32dda3888c62d 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -52,35 +52,19 @@ import itertools from copy_reg import _extension_registry, _inverted_registry, _extension_cache import new -import dis import traceback +import platform -#relevant opcodes -STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL')) -DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL')) -LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL')) -GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL] +PyImp = platform.python_implementation() -HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT) -EXTENDED_ARG = chr(dis.EXTENDED_ARG) import logging cloudLog = logging.getLogger("Cloud.Transport") -try: - import ctypes -except (MemoryError, ImportError): - logging.warning('Exception raised on importing ctypes. Likely python bug.. some functionality will be disabled', exc_info = True) - ctypes = None - PyObject_HEAD = None -else: - - # for reading internal structures - PyObject_HEAD = [ - ('ob_refcnt', ctypes.c_size_t), - ('ob_type', ctypes.c_void_p), - ] +if PyImp == "PyPy": + # register builtin type in `new` + new.method = types.MethodType try: from cStringIO import StringIO @@ -225,6 +209,8 @@ def save_function(self, obj, name=None, pack=struct.pack): if themodule: self.modules.add(themodule) + if getattr(themodule, name, None) is obj: + return self.save_global(obj, name) if not self.savedDjangoEnv: #hack for django - if we detect the settings module, we transport it @@ -306,44 +292,28 @@ def save_function_tuple(self, func, forced_imports): # create a skeleton function object and memoize it save(_make_skel_func) - save((code, len(closure), base_globals)) + save((code, closure, base_globals)) write(pickle.REDUCE) self.memoize(func) # save the rest of the func data needed by _fill_function save(f_globals) save(defaults) - save(closure) save(dct) write(pickle.TUPLE) write(pickle.REDUCE) # applies _fill_function on the tuple @staticmethod - def extract_code_globals(co): + def extract_code_globals(code): """ Find all globals names read or written to by codeblock co """ - code = co.co_code - names = co.co_names - out_names = set() - - n = len(code) - i = 0 - extended_arg = 0 - while i < n: - op = code[i] - - i = i+1 - if op >= HAVE_ARGUMENT: - oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg - extended_arg = 0 - i = i+2 - if op == EXTENDED_ARG: - extended_arg = oparg*65536L - if op in GLOBAL_OPS: - out_names.add(names[oparg]) - #print 'extracted', out_names, ' from ', names - return out_names + names = set(code.co_names) + if code.co_consts: # see if nested function have any global refs + for const in code.co_consts: + if type(const) is types.CodeType: + names |= CloudPickler.extract_code_globals(const) + return names def extract_func_data(self, func): """ @@ -354,10 +324,7 @@ def extract_func_data(self, func): # extract all global ref's func_global_refs = CloudPickler.extract_code_globals(code) - if code.co_consts: # see if nested function have any global refs - for const in code.co_consts: - if type(const) is types.CodeType and const.co_names: - func_global_refs = func_global_refs.union( CloudPickler.extract_code_globals(const)) + # process all variables referenced by global environment f_globals = {} for var in func_global_refs: @@ -396,6 +363,12 @@ def get_contents(cell): return (code, f_globals, defaults, closure, dct, base_globals) + def 
save_builtin_function(self, obj): + if obj.__module__ is "__builtin__": + return self.save_global(obj) + return self.save_function(obj) + dispatch[types.BuiltinFunctionType] = save_builtin_function + def save_global(self, obj, name=None, pack=struct.pack): write = self.write memo = self.memo @@ -435,7 +408,7 @@ def save_global(self, obj, name=None, pack=struct.pack): try: klass = getattr(themodule, name) except AttributeError, a: - #print themodule, name, obj, type(obj) + # print themodule, name, obj, type(obj) raise pickle.PicklingError("Can't pickle builtin %s" % obj) else: raise @@ -480,7 +453,6 @@ def save_global(self, obj, name=None, pack=struct.pack): write(pickle.GLOBAL + modname + '\n' + name + '\n') self.memoize(obj) dispatch[types.ClassType] = save_global - dispatch[types.BuiltinFunctionType] = save_global dispatch[types.TypeType] = save_global def save_instancemethod(self, obj): @@ -551,23 +523,39 @@ def save_property(self, obj): dispatch[property] = save_property def save_itemgetter(self, obj): - """itemgetter serializer (needed for namedtuple support) - a bit of a pain as we need to read ctypes internals""" - class ItemGetterType(ctypes.Structure): - _fields_ = PyObject_HEAD + [ - ('nitems', ctypes.c_size_t), - ('item', ctypes.py_object) - ] - - - obj = ctypes.cast(ctypes.c_void_p(id(obj)), ctypes.POINTER(ItemGetterType)).contents - return self.save_reduce(operator.itemgetter, - obj.item if obj.nitems > 1 else (obj.item,)) - - if PyObject_HEAD: + """itemgetter serializer (needed for namedtuple support)""" + class Dummy: + def __getitem__(self, item): + return item + items = obj(Dummy()) + if not isinstance(items, tuple): + items = (items, ) + return self.save_reduce(operator.itemgetter, items) + + if type(operator.itemgetter) is type: dispatch[operator.itemgetter] = save_itemgetter + def save_attrgetter(self, obj): + """attrgetter serializer""" + class Dummy(object): + def __init__(self, attrs, index=None): + self.attrs = attrs + self.index = index + def __getattribute__(self, item): + attrs = object.__getattribute__(self, "attrs") + index = object.__getattribute__(self, "index") + if index is None: + index = len(attrs) + attrs.append(item) + else: + attrs[index] = ".".join([attrs[index], item]) + return type(self)(attrs, index) + attrs = [] + obj(Dummy(attrs)) + return self.save_reduce(operator.attrgetter, tuple(attrs)) + if type(operator.attrgetter) is type: + dispatch[operator.attrgetter] = save_attrgetter def save_reduce(self, func, args, state=None, listitems=None, dictitems=None, obj=None): @@ -660,11 +648,11 @@ def save_file(self, obj): if not hasattr(obj, 'name') or not hasattr(obj, 'mode'): raise pickle.PicklingError("Cannot pickle files that do not map to an actual file") - if obj.name == '': + if obj is sys.stdout: return self.save_reduce(getattr, (sys,'stdout'), obj=obj) - if obj.name == '': + if obj is sys.stderr: return self.save_reduce(getattr, (sys,'stderr'), obj=obj) - if obj.name == '': + if obj is sys.stdin: raise pickle.PicklingError("Cannot pickle standard input") if hasattr(obj, 'isatty') and obj.isatty(): raise pickle.PicklingError("Cannot pickle files that map to tty objects") @@ -873,8 +861,7 @@ def _genpartial(func, args, kwds): kwds = {} return partial(func, *args, **kwds) - -def _fill_function(func, globals, defaults, closure, dict): +def _fill_function(func, globals, defaults, dict): """ Fills in the rest of function data into the skeleton function object that were created via _make_skel_func(). 
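The save_itemgetter / save_attrgetter replacements above drop the old ctypes introspection and instead probe the getter with a dummy object that records what it is asked for. The same idea in isolation, as a simplified sketch rather than the pickler code itself:

    import operator

    class Recorder(object):
        """Records the attribute names an attrgetter asks for."""
        def __init__(self):
            self.seen = []
        def __getattr__(self, name):
            self.seen.append(name)
            return self

    probe = Recorder()
    operator.attrgetter("name", "age")(probe)
    assert probe.seen == ["name", "age"]

    # The recorded names are enough to rebuild an equivalent getter; the real
    # save_attrgetter above additionally re-joins dotted chains such as "a.b".
    rebuilt = operator.attrgetter(*probe.seen)
    assert rebuilt(type("P", (), {"name": "x", "age": 3})()) == ("x", 3)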
""" @@ -882,49 +869,28 @@ def _fill_function(func, globals, defaults, closure, dict): func.func_defaults = defaults func.func_dict = dict - if len(closure) != len(func.func_closure): - raise pickle.UnpicklingError("closure lengths don't match up") - for i in range(len(closure)): - _change_cell_value(func.func_closure[i], closure[i]) - return func -def _make_skel_func(code, num_closures, base_globals = None): +def _make_cell(value): + return (lambda: value).func_closure[0] + +def _reconstruct_closure(values): + return tuple([_make_cell(v) for v in values]) + +def _make_skel_func(code, closures, base_globals = None): """ Creates a skeleton function object that contains just the provided code and the correct number of cells in func_closure. All other func attributes (e.g. func_globals) are empty. """ - #build closure (cells): - if not ctypes: - raise Exception('ctypes failed to import; cannot build function') - - cellnew = ctypes.pythonapi.PyCell_New - cellnew.restype = ctypes.py_object - cellnew.argtypes = (ctypes.py_object,) - dummy_closure = tuple(map(lambda i: cellnew(None), range(num_closures))) + closure = _reconstruct_closure(closures) if closures else None if base_globals is None: base_globals = {} base_globals['__builtins__'] = __builtins__ return types.FunctionType(code, base_globals, - None, None, dummy_closure) - -# this piece of opaque code is needed below to modify 'cell' contents -cell_changer_code = new.code( - 1, 1, 2, 0, - ''.join([ - chr(dis.opmap['LOAD_FAST']), '\x00\x00', - chr(dis.opmap['DUP_TOP']), - chr(dis.opmap['STORE_DEREF']), '\x00\x00', - chr(dis.opmap['RETURN_VALUE']) - ]), - (), (), ('newval',), '', 'cell_changer', 1, '', ('c',), () -) - -def _change_cell_value(cell, newval): - """ Changes the contents of 'cell' object to newval """ - return new.function(cell_changer_code, {}, None, (), (cell,))(newval) + None, None, closure) + """Constructors for 3rd party libraries Note: These can never be renamed due to client compatibility issues""" diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 84bc0a3b7ccd0..a17f2c1203d36 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -20,7 +20,6 @@ import sys from threading import Lock from tempfile import NamedTemporaryFile -from collections import namedtuple from pyspark import accumulators from pyspark.accumulators import Accumulator @@ -33,6 +32,7 @@ from pyspark.storagelevel import StorageLevel from pyspark import rdd from pyspark.rdd import RDD +from pyspark.traceback_utils import CallSite, first_spark_call from py4j.java_collections import ListConverter @@ -53,7 +53,7 @@ class SparkContext(object): """ Main entry point for Spark functionality. A SparkContext represents the - connection to a Spark cluster, and can be used to create L{RDD}s and + connection to a Spark cluster, and can be used to create L{RDD} and broadcast variables on that cluster. """ @@ -99,11 +99,7 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, ... ValueError:... 
""" - if rdd._extract_concise_traceback() is not None: - self._callsite = rdd._extract_concise_traceback() - else: - tempNamedTuple = namedtuple("Callsite", "function file linenum") - self._callsite = tempNamedTuple(function=None, file=None, linenum=None) + self._callsite = first_spark_call() or CallSite(None, None, None) SparkContext._ensure_initialized(self, gateway=gateway) try: self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer, @@ -214,6 +210,7 @@ def _ensure_initialized(cls, instance=None, gateway=None): SparkContext._gateway = gateway or launch_gateway() SparkContext._jvm = SparkContext._gateway.jvm SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile + SparkContext._jvm.SerDeUtil.initialize() if instance: if (SparkContext._active_spark_context and @@ -331,12 +328,16 @@ def pickleFile(self, name, minPartitions=None): return RDD(self._jsc.objectFile(name, minPartitions), self, BatchedSerializer(PickleSerializer())) - def textFile(self, name, minPartitions=None): + def textFile(self, name, minPartitions=None, use_unicode=True): """ Read a text file from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI, and return it as an RDD of Strings. + If use_unicode is False, the strings will be kept as `str` (encoding + as `utf-8`), which is faster and smaller than unicode. (Added in + Spark 1.2) + >>> path = os.path.join(tempdir, "sample-text.txt") >>> with open(path, "w") as testFile: ... testFile.write("Hello world!") @@ -346,9 +347,9 @@ def textFile(self, name, minPartitions=None): """ minPartitions = minPartitions or min(self.defaultParallelism, 2) return RDD(self._jsc.textFile(name, minPartitions), self, - UTF8Deserializer()) + UTF8Deserializer(use_unicode)) - def wholeTextFiles(self, path, minPartitions=None): + def wholeTextFiles(self, path, minPartitions=None, use_unicode=True): """ Read a directory of text files from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system @@ -356,6 +357,10 @@ def wholeTextFiles(self, path, minPartitions=None): key-value pair, where the key is the path of each file, the value is the content of each file. + If use_unicode is False, the strings will be kept as `str` (encoding + as `utf-8`), which is faster and smaller than unicode. (Added in + Spark 1.2) + For example, if you have the following files:: hdfs://a-hdfs-path/part-00000 @@ -386,7 +391,7 @@ def wholeTextFiles(self, path, minPartitions=None): """ minPartitions = minPartitions or self.defaultMinPartitions return RDD(self._jsc.wholeTextFiles(path, minPartitions), self, - PairDeserializer(UTF8Deserializer(), UTF8Deserializer())) + PairDeserializer(UTF8Deserializer(use_unicode), UTF8Deserializer(use_unicode))) def _dictToJavaMap(self, d): jm = self._jvm.java.util.HashMap() diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 22ab8d30c0ae3..64d6202acb27d 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -23,6 +23,7 @@ import sys import traceback import time +import gc from errno import EINTR, ECHILD, EAGAIN from socket import AF_INET, SOCK_STREAM, SOMAXCONN from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN @@ -42,25 +43,10 @@ def worker(sock): """ Called by a worker process after the fork(). 
""" - # Redirect stdout to stderr - os.dup2(2, 1) - sys.stdout = sys.stderr # The sys.stdout object is different from file descriptor 1 - signal.signal(SIGHUP, SIG_DFL) signal.signal(SIGCHLD, SIG_DFL) signal.signal(SIGTERM, SIG_DFL) - # Blocks until the socket is closed by draining the input stream - # until it raises an exception or returns EOF. - def waitSocketClose(sock): - try: - while True: - # Empty string is returned upon EOF (and only then). - if sock.recv(4096) == '': - return - except: - pass - # Read the socket using fdopen instead of socket.makefile() because the latter # seems to be very slow; note that we need to dup() the file descriptor because # otherwise writes also cause a seek that makes us miss data on the read side. @@ -68,17 +54,13 @@ def waitSocketClose(sock): outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) exit_code = 0 try: - # Acknowledge that the fork was successful - write_int(os.getpid(), outfile) - outfile.flush() worker_main(infile, outfile) except SystemExit as exc: - exit_code = exc.code + exit_code = compute_real_exit_code(exc.code) finally: outfile.flush() - # The Scala side will close the socket upon task completion. - waitSocketClose(sock) - os._exit(compute_real_exit_code(exit_code)) + if exit_code: + os._exit(exit_code) # Cleanup zombie children @@ -102,6 +84,7 @@ def manager(): listen_sock.listen(max(1024, SOMAXCONN)) listen_host, listen_port = listen_sock.getsockname() write_int(listen_port, sys.stdout) + sys.stdout.flush() def shutdown(code): signal.signal(SIGTERM, SIG_DFL) @@ -114,8 +97,9 @@ def handle_sigterm(*args): signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP + reuse = os.environ.get("SPARK_REUSE_WORKER") + # Initialization complete - sys.stdout.close() try: while True: try: @@ -167,7 +151,19 @@ def handle_sigterm(*args): # in child process listen_sock.close() try: - worker(sock) + # Acknowledge that the fork was successful + outfile = sock.makefile("w") + write_int(os.getpid(), outfile) + outfile.flush() + outfile.close() + while True: + worker(sock) + if not reuse: + # wait for closing + while sock.recv(1024): + pass + break + gc.collect() except: traceback.print_exc() os._exit(1) diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index bb60d3d0c8463..68f6033616726 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -21,7 +21,7 @@ from numpy import ndarray, float64, int64, int32, array_equal, array from pyspark import SparkContext, RDD from pyspark.mllib.linalg import SparseVector -from pyspark.serializers import Serializer +from pyspark.serializers import FramedSerializer """ @@ -451,18 +451,16 @@ def _serialize_rating(r): return ba -class RatingDeserializer(Serializer): +class RatingDeserializer(FramedSerializer): - def loads(self, stream): - length = struct.unpack("!i", stream.read(4))[0] - ba = stream.read(length) - res = ndarray(shape=(3, ), buffer=ba, dtype=float64, offset=4) + def loads(self, string): + res = ndarray(shape=(3, ), buffer=string, dtype=float64, offset=4) return int(res[0]), int(res[1]), res[2] def load_stream(self, stream): while True: try: - yield self.loads(stream) + yield self._read_with_length(stream) except struct.error: return except EOFError: diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index ccc000ac70ba6..5b13ab682bbfc 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -138,7 +138,8 @@ class 
DecisionTree(object): @staticmethod def trainClassifier(data, numClasses, categoricalFeaturesInfo, - impurity="gini", maxDepth=5, maxBins=32): + impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1, + minInfoGain=0.0): """ Train a DecisionTreeModel for classification. @@ -154,6 +155,9 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo, E.g., depth 0 means 1 leaf node. Depth 1 means 1 internal node + 2 leaf nodes. :param maxBins: Number of bins used for finding splits at each node. + :param minInstancesPerNode: Min number of instances required at child nodes to create + the parent split + :param minInfoGain: Min info gain required to create a split :return: DecisionTreeModel """ sc = data.context @@ -164,13 +168,14 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo, model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( dataBytes._jrdd, "classification", numClasses, categoricalFeaturesInfoJMap, - impurity, maxDepth, maxBins) + impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) dataBytes.unpersist() return DecisionTreeModel(sc, model) @staticmethod def trainRegressor(data, categoricalFeaturesInfo, - impurity="variance", maxDepth=5, maxBins=32): + impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1, + minInfoGain=0.0): """ Train a DecisionTreeModel for regression. @@ -185,6 +190,9 @@ def trainRegressor(data, categoricalFeaturesInfo, E.g., depth 0 means 1 leaf node. Depth 1 means 1 internal node + 2 leaf nodes. :param maxBins: Number of bins used for finding splits at each node. + :param minInstancesPerNode: Min number of instances required at child nodes to create + the parent split + :param minInfoGain: Min info gain required to create a split :return: DecisionTreeModel """ sc = data.context @@ -195,7 +203,7 @@ def trainRegressor(data, categoricalFeaturesInfo, model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( dataBytes._jrdd, "regression", 0, categoricalFeaturesInfoJMap, - impurity, maxDepth, maxBins) + impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) dataBytes.unpersist() return DecisionTreeModel(sc, model) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 5667154cb84a8..cb09c191bed71 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -18,13 +18,11 @@ from base64 import standard_b64encode as b64enc import copy from collections import defaultdict -from collections import namedtuple from itertools import chain, ifilter, imap import operator import os import sys import shlex -import traceback from subprocess import Popen, PIPE from tempfile import NamedTemporaryFile from threading import Thread @@ -45,6 +43,7 @@ from pyspark.resultiterable import ResultIterable from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ get_used_memory, ExternalSorter +from pyspark.traceback_utils import SCCallSiteSync from py4j.java_collections import ListConverter, MapConverter @@ -81,57 +80,6 @@ def portable_hash(x): return hash(x) -def _extract_concise_traceback(): - """ - This function returns the traceback info for a callsite, returns a dict - with function name, file name and line number - """ - tb = traceback.extract_stack() - callsite = namedtuple("Callsite", "function file linenum") - if len(tb) == 0: - return None - file, line, module, what = tb[len(tb) - 1] - sparkpath = os.path.dirname(file) - first_spark_frame = len(tb) - 1 - for i in range(0, len(tb)): - file, line, fun, what = tb[i] - if file.startswith(sparkpath): - first_spark_frame = i - break - if 
first_spark_frame == 0: - file, line, fun, what = tb[0] - return callsite(function=fun, file=file, linenum=line) - sfile, sline, sfun, swhat = tb[first_spark_frame] - ufile, uline, ufun, uwhat = tb[first_spark_frame - 1] - return callsite(function=sfun, file=ufile, linenum=uline) - -_spark_stack_depth = 0 - - -class _JavaStackTrace(object): - - def __init__(self, sc): - tb = _extract_concise_traceback() - if tb is not None: - self._traceback = "%s at %s:%s" % ( - tb.function, tb.file, tb.linenum) - else: - self._traceback = "Error! Could not extract traceback info" - self._context = sc - - def __enter__(self): - global _spark_stack_depth - if _spark_stack_depth == 0: - self._context._jsc.setCallSite(self._traceback) - _spark_stack_depth += 1 - - def __exit__(self, type, value, tb): - global _spark_stack_depth - _spark_stack_depth -= 1 - if _spark_stack_depth == 0: - self._context._jsc.setCallSite(None) - - class BoundedFloat(float): """ Bounded value is generated by approximate job, with confidence and low @@ -353,7 +301,7 @@ def func(iterator): return ifilter(f, iterator) return self.mapPartitions(func, True) - def distinct(self): + def distinct(self, numPartitions=None): """ Return a new RDD containing the distinct elements in this RDD. @@ -361,7 +309,7 @@ def distinct(self): [1, 2, 3] """ return self.map(lambda x: (x, None)) \ - .reduceByKey(lambda x, _: x) \ + .reduceByKey(lambda x, _: x, numPartitions) \ .map(lambda (x, _): x) def sample(self, withReplacement, fraction, seed=None): @@ -704,7 +652,7 @@ def collect(self): """ Return a list that contains all of the elements in this RDD. """ - with _JavaStackTrace(self.context) as st: + with SCCallSiteSync(self.context) as css: bytesInJava = self._jrdd.collect().iterator() return list(self._collect_iterator_through_file(bytesInJava)) @@ -1060,6 +1008,7 @@ def top(self, num, key=None): Get the top N elements from a RDD. Note: It returns the list sorted in descending order. + >>> sc.parallelize([10, 4, 2, 12, 3]).top(1) [12] >>> sc.parallelize([2, 3, 4, 5, 6], 2).top(2) @@ -1514,7 +1463,7 @@ def add_shuffle_key(split, iterator): keyed = self.mapPartitionsWithIndex(add_shuffle_key) keyed._bypass_serializer = True - with _JavaStackTrace(self.context) as st: + with SCCallSiteSync(self.context) as css: pairRDD = self.ctx._jvm.PairwiseRDD( keyed._jrdd.rdd()).asJavaPairRDD() partitioner = self.ctx._jvm.PythonPartitioner(numPartitions, diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 55e6cf3308611..44ac5642836e0 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -110,6 +110,9 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + def __repr__(self): + return "<%s object>" % self.__class__.__name__ + class FramedSerializer(Serializer): @@ -144,6 +147,8 @@ def _write_with_length(self, obj, stream): def _read_with_length(self, stream): length = read_int(stream) + if length == SpecialLengths.END_OF_DATA_SECTION: + raise EOFError obj = stream.read(length) if obj == "": raise EOFError @@ -355,7 +360,8 @@ class PickleSerializer(FramedSerializer): def dumps(self, obj): return cPickle.dumps(obj, 2) - loads = cPickle.loads + def loads(self, obj): + return cPickle.loads(obj) class CloudPickleSerializer(PickleSerializer): @@ -374,8 +380,11 @@ class MarshalSerializer(FramedSerializer): This serializer is faster than PickleSerializer but supports fewer datatypes. 
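Following the serializer note above, a small sketch of plugging `MarshalSerializer` into a `SparkContext`; the master URL and the sample job are assumptions for illustration.

    from pyspark import SparkContext
    from pyspark.serializers import MarshalSerializer

    # MarshalSerializer trades generality for speed: it handles simple types
    # (ints, strings, lists, ...) but not arbitrary Python objects.
    sc = SparkContext("local", "marshal-sketch", serializer=MarshalSerializer())
    print(sc.parallelize(range(10)).map(lambda x: x * 2).sum())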
""" - dumps = marshal.dumps - loads = marshal.loads + def dumps(self, obj): + return marshal.dumps(obj) + + def loads(self, obj): + return marshal.loads(obj) class AutoSerializer(FramedSerializer): @@ -429,18 +438,24 @@ class UTF8Deserializer(Serializer): Deserializes streams written by String.getBytes. """ + def __init__(self, use_unicode=False): + self.use_unicode = use_unicode + def loads(self, stream): length = read_int(stream) - return stream.read(length).decode('utf8') + if length == SpecialLengths.END_OF_DATA_SECTION: + raise EOFError + s = stream.read(length) + return s.decode("utf-8") if self.use_unicode else s def load_stream(self, stream): - while True: - try: + try: + while True: yield self.loads(stream) - except struct.error: - return - except EOFError: - return + except struct.error: + return + except EOFError: + return def read_long(stream): diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 49829f5280a5f..ce597cbe91e15 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -68,6 +68,11 @@ def _get_local_dirs(sub): return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs] +# global stats +MemoryBytesSpilled = 0L +DiskBytesSpilled = 0L + + class Aggregator(object): """ @@ -313,10 +318,12 @@ def _spill(self): It will dump the data in batch for better performance. """ + global MemoryBytesSpilled, DiskBytesSpilled path = self._get_spill_dir(self.spills) if not os.path.exists(path): os.makedirs(path) + used_memory = get_used_memory() if not self.pdata: # The data has not been partitioned, it will iterator the # dataset once, write them into different files, has no @@ -334,6 +341,7 @@ def _spill(self): self.serializer.dump_stream([(k, v)], streams[h]) for s in streams: + DiskBytesSpilled += s.tell() s.close() self.data.clear() @@ -346,9 +354,11 @@ def _spill(self): # dump items in batch self.serializer.dump_stream(self.pdata[i].iteritems(), f) self.pdata[i].clear() + DiskBytesSpilled += os.path.getsize(p) self.spills += 1 gc.collect() # release the memory as much as possible + MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 def iteritems(self): """ Return all merged items as iterator """ @@ -462,7 +472,6 @@ def __init__(self, memory_limit, serializer=None): self.memory_limit = memory_limit self.local_dirs = _get_local_dirs("sort") self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024) - self._spilled_bytes = 0 def _get_path(self, n): """ Choose one directory for spill by number n """ @@ -476,6 +485,7 @@ def sorted(self, iterator, key=None, reverse=False): Sort the elements in iterator, do external sort when the memory goes above the limit. 
""" + global MemoryBytesSpilled, DiskBytesSpilled batch = 10 chunks, current_chunk = [], [] iterator = iter(iterator) @@ -486,15 +496,18 @@ def sorted(self, iterator, key=None, reverse=False): if len(chunk) < batch: break - if get_used_memory() > self.memory_limit: + used_memory = get_used_memory() + if used_memory > self.memory_limit: # sort them inplace will save memory current_chunk.sort(key=key, reverse=reverse) path = self._get_path(len(chunks)) with open(path, 'w') as f: self.serializer.dump_stream(current_chunk, f) - self._spilled_bytes += os.path.getsize(path) chunks.append(self.serializer.load_stream(open(path))) current_chunk = [] + gc.collect() + MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + DiskBytesSpilled += os.path.getsize(path) elif not chunks: batch = min(batch * 2, 10000) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 53eea6d6cf3ba..8f6dbab240c7b 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -30,6 +30,7 @@ from pyspark.rdd import RDD, PipelinedRDD from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer from pyspark.storagelevel import StorageLevel +from pyspark.traceback_utils import SCCallSiteSync from itertools import chain, ifilter, imap @@ -288,7 +289,7 @@ class StructType(DataType): """Spark SQL StructType The data type representing rows. - A StructType object comprises a list of L{StructField}s. + A StructType object comprises a list of L{StructField}. """ @@ -903,7 +904,7 @@ class SQLContext(object): """Main entry point for Spark SQL functionality. - A SQLContext can be used create L{SchemaRDD}s, register L{SchemaRDD}s as + A SQLContext can be used create L{SchemaRDD}, register L{SchemaRDD} as tables, execute SQL over tables, cache tables, and read parquet files. """ @@ -993,7 +994,7 @@ def registerFunction(self, name, f, returnType=StringType()): str(returnType)) def inferSchema(self, rdd): - """Infer and apply a schema to an RDD of L{Row}s. + """Infer and apply a schema to an RDD of L{Row}. We peek at the first row of the RDD to determine the fields' names and types. Nested collections are supported, which include array, @@ -1046,7 +1047,7 @@ def inferSchema(self, rdd): def applySchema(self, rdd, schema): """ - Applies the given schema to the given RDD of L{tuple} or L{list}s. + Applies the given schema to the given RDD of L{tuple} or L{list}. These tuples or lists can contain complex nested structures like lists, maps or nested rows. @@ -1122,7 +1123,7 @@ def applySchema(self, rdd, schema): batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) jrdd = self._pythonToJava(rdd._jrdd, batched) srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema)) - return SchemaRDD(srdd, self) + return SchemaRDD(srdd.toJavaSchemaRDD(), self) def registerRDDAsTable(self, rdd, tableName): """Registers the given RDD as a temporary table in the catalog. 
@@ -1134,8 +1135,8 @@ def registerRDDAsTable(self, rdd, tableName): >>> sqlCtx.registerRDDAsTable(srdd, "table1") """ if (rdd.__class__ is SchemaRDD): - jschema_rdd = rdd._jschema_rdd - self._ssql_ctx.registerRDDAsTable(jschema_rdd, tableName) + srdd = rdd._jschema_rdd.baseSchemaRDD() + self._ssql_ctx.registerRDDAsTable(srdd, tableName) else: raise ValueError("Can only register SchemaRDD as table") @@ -1151,7 +1152,7 @@ def parquetFile(self, path): >>> sorted(srdd.collect()) == sorted(srdd2.collect()) True """ - jschema_rdd = self._ssql_ctx.parquetFile(path) + jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD() return SchemaRDD(jschema_rdd, self) def jsonFile(self, path, schema=None): @@ -1182,6 +1183,7 @@ def jsonFile(self, path, schema=None): Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema()) >>> sqlCtx.registerRDDAsTable(srdd3, "table2") >>> srdd4 = sqlCtx.sql( @@ -1192,6 +1194,7 @@ def jsonFile(self, path, schema=None): Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + >>> schema = StructType([ ... StructField("field2", StringType(), True), ... StructField("field3", @@ -1207,11 +1210,11 @@ def jsonFile(self, path, schema=None): [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] """ if schema is None: - jschema_rdd = self._ssql_ctx.jsonFile(path) + srdd = self._ssql_ctx.jsonFile(path) else: scala_datatype = self._ssql_ctx.parseDataType(str(schema)) - jschema_rdd = self._ssql_ctx.jsonFile(path, scala_datatype) - return SchemaRDD(jschema_rdd, self) + srdd = self._ssql_ctx.jsonFile(path, scala_datatype) + return SchemaRDD(srdd.toJavaSchemaRDD(), self) def jsonRDD(self, rdd, schema=None): """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. @@ -1232,6 +1235,7 @@ def jsonRDD(self, rdd, schema=None): Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema()) >>> sqlCtx.registerRDDAsTable(srdd3, "table2") >>> srdd4 = sqlCtx.sql( @@ -1242,6 +1246,7 @@ def jsonRDD(self, rdd, schema=None): Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + >>> schema = StructType([ ... StructField("field2", StringType(), True), ... StructField("field3", @@ -1275,11 +1280,11 @@ def func(iterator): keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) if schema is None: - jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) + srdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) else: scala_datatype = self._ssql_ctx.parseDataType(str(schema)) - jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) - return SchemaRDD(jschema_rdd, self) + srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) + return SchemaRDD(srdd.toJavaSchemaRDD(), self) def sql(self, sqlQuery): """Return a L{SchemaRDD} representing the result of the given query. 
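A condensed sketch of the JSON entry points shown above, with and without an explicit schema; the sample records are assumptions, and an existing SparkContext `sc` is assumed.

    from pyspark.sql import SQLContext, StructType, StructField, StringType

    sqlCtx = SQLContext(sc)  # assumes an existing SparkContext `sc`
    json = sc.parallelize(['{"f1": 1, "f2": "row1"}', '{"f1": 2, "f2": "row2"}'])

    # Infer the schema from the JSON records; the SchemaRDD is now backed by
    # a JavaSchemaRDD on the JVM side.
    srdd1 = sqlCtx.jsonRDD(json)

    # Or supply an explicit schema to keep only the fields of interest.
    schema = StructType([StructField("f2", StringType(), True)])
    srdd2 = sqlCtx.jsonRDD(json, schema)
    print(srdd2.collect())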
@@ -1290,7 +1295,7 @@ def sql(self, sqlQuery): >>> srdd2.collect() [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] """ - return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self) + return SchemaRDD(self._ssql_ctx.sql(sqlQuery).toJavaSchemaRDD(), self) def table(self, tableName): """Returns the specified table as a L{SchemaRDD}. @@ -1301,7 +1306,7 @@ def table(self, tableName): >>> sorted(srdd.collect()) == sorted(srdd2.collect()) True """ - return SchemaRDD(self._ssql_ctx.table(tableName), self) + return SchemaRDD(self._ssql_ctx.table(tableName).toJavaSchemaRDD(), self) def cacheTable(self, tableName): """Caches the specified table in-memory.""" @@ -1353,7 +1358,7 @@ def hiveql(self, hqlQuery): warnings.warn("hiveql() is deprecated as the sql function now parses using HiveQL by" + "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'", DeprecationWarning) - return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self) + return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery).toJavaSchemaRDD(), self) def hql(self, hqlQuery): """ @@ -1524,6 +1529,8 @@ class SchemaRDD(RDD): def __init__(self, jschema_rdd, sql_ctx): self.sql_ctx = sql_ctx self._sc = sql_ctx._sc + clsName = jschema_rdd.getClass().getName() + assert clsName.endswith("JavaSchemaRDD"), "jschema_rdd must be JavaSchemaRDD" self._jschema_rdd = jschema_rdd self._id = None self.is_cached = False @@ -1540,7 +1547,7 @@ def _jrdd(self): L{pyspark.rdd.RDD} super class (map, filter, etc.). """ if not hasattr(self, '_lazy_jrdd'): - self._lazy_jrdd = self._jschema_rdd.javaToPython() + self._lazy_jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython() return self._lazy_jrdd def id(self): @@ -1548,6 +1555,18 @@ def id(self): self._id = self._jrdd.id() return self._id + def limit(self, num): + """Limit the result count to the number specified. + + >>> srdd = sqlCtx.inferSchema(rdd) + >>> srdd.limit(2).collect() + [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] + >>> srdd.limit(0).collect() + [] + """ + rdd = self._jschema_rdd.baseSchemaRDD().limit(num).toJavaSchemaRDD() + return SchemaRDD(rdd, self.sql_ctx) + def saveAsParquetFile(self, path): """Save the contents as a Parquet file, preserving the schema. @@ -1598,7 +1617,7 @@ def saveAsTable(self, tableName): def schema(self): """Returns the schema of this SchemaRDD (represented by a L{StructType}).""" - return _parse_datatype_string(self._jschema_rdd.schema().toString()) + return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString()) def schemaString(self): """Returns the output schema in the tree format.""" @@ -1624,15 +1643,39 @@ def count(self): return self._jschema_rdd.count() def collect(self): - """ - Return a list that contains all of the rows in this RDD. + """Return a list that contains all of the rows in this RDD. - Each object in the list is on Row, the fields can be accessed as + Each object in the list is a Row, the fields can be accessed as attributes. + + Unlike the base RDD implementation of collect, this implementation + leverages the query optimizer to perform a collect on the SchemaRDD, + which supports features such as filter pushdown. 
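A brief sketch of the new `limit` method and the optimized `take`/`collect` paths described above; the sample JSON records are assumptions, and an existing SparkContext `sc` is assumed.

    from pyspark.sql import SQLContext

    sqlCtx = SQLContext(sc)  # assumes an existing SparkContext `sc`
    json = sc.parallelize(['{"field1": %d, "field2": "row%d"}' % (i, i) for i in range(1, 4)])
    srdd = sqlCtx.jsonRDD(json)

    # limit() builds the Limit operator on the JVM side and wraps it in a new SchemaRDD.
    print(srdd.limit(2).collect())

    # take() is now limit(num).collect(), and collect() goes through
    # collectToPython(), so both run through the query optimizer.
    print(srdd.take(2))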
+ + >>> srdd = sqlCtx.inferSchema(rdd) + >>> srdd.collect() + [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')] """ - rows = RDD.collect(self) + with SCCallSiteSync(self.context) as css: + bytesInJava = self._jschema_rdd.baseSchemaRDD().collectToPython().iterator() cls = _create_cls(self.schema()) - return map(cls, rows) + return map(cls, self._collect_iterator_through_file(bytesInJava)) + + def take(self, num): + """Take the first num rows of the RDD. + + Each object in the list is a Row, the fields can be accessed as + attributes. + + Unlike the base RDD implementation of take, this implementation + leverages the query optimizer to perform a collect on a SchemaRDD, + which supports features such as filter pushdown. + + >>> srdd = sqlCtx.inferSchema(rdd) + >>> srdd.take(2) + [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] + """ + return self.limit(num).collect() # Convert each object in the RDD to a Row with the right class # for this SchemaRDD, so that fields can be accessed as attributes. @@ -1649,8 +1692,6 @@ def mapPartitionsWithIndex(self, f, preservesPartitioning=False): rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer) schema = self.schema() - import pickle - pickle.loads(pickle.dumps(schema)) def applySchema(_, it): cls = _create_cls(schema) @@ -1687,17 +1728,18 @@ def isCheckpointed(self): def getCheckpointFile(self): checkpointFile = self._jschema_rdd.getCheckpointFile() - if checkpointFile.isDefined(): + if checkpointFile.isPresent(): return checkpointFile.get() - else: - return None def coalesce(self, numPartitions, shuffle=False): rdd = self._jschema_rdd.coalesce(numPartitions, shuffle) return SchemaRDD(rdd, self.sql_ctx) - def distinct(self): - rdd = self._jschema_rdd.distinct() + def distinct(self, numPartitions=None): + if numPartitions is None: + rdd = self._jschema_rdd.distinct() + else: + rdd = self._jschema_rdd.distinct(numPartitions) return SchemaRDD(rdd, self.sql_ctx) def intersection(self, other): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index bb84ebe72cb24..0b3854347ad2e 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -31,6 +31,7 @@ import time import zipfile import random +from platform import python_implementation if sys.version_info[:2] <= (2, 6): import unittest2 as unittest @@ -41,9 +42,11 @@ from pyspark.conf import SparkConf from pyspark.context import SparkContext from pyspark.files import SparkFiles -from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer +from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ + CloudPickleSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter from pyspark.sql import SQLContext, IntegerType +from pyspark import shuffle _have_scipy = False _have_numpy = False @@ -136,17 +139,17 @@ def test_external_sort(self): random.shuffle(l) sorter = ExternalSorter(1) self.assertEquals(sorted(l), list(sorter.sorted(l))) - self.assertGreater(sorter._spilled_bytes, 0) - last = sorter._spilled_bytes + self.assertGreater(shuffle.DiskBytesSpilled, 0) + last = shuffle.DiskBytesSpilled self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) - self.assertGreater(sorter._spilled_bytes, last) - last = sorter._spilled_bytes + self.assertGreater(shuffle.DiskBytesSpilled, last) + last = shuffle.DiskBytesSpilled self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) - 
self.assertGreater(sorter._spilled_bytes, last) - last = sorter._spilled_bytes + self.assertGreater(shuffle.DiskBytesSpilled, last) + last = shuffle.DiskBytesSpilled self.assertEquals(sorted(l, key=lambda x: -x, reverse=True), list(sorter.sorted(l, key=lambda x: -x, reverse=True))) - self.assertGreater(sorter._spilled_bytes, last) + self.assertGreater(shuffle.DiskBytesSpilled, last) def test_external_sort_in_rdd(self): conf = SparkConf().set("spark.python.worker.memory", "1m") @@ -168,15 +171,46 @@ def test_namedtuple(self): p2 = loads(dumps(p1, 2)) self.assertEquals(p1, p2) - -# Regression test for SPARK-3415 -class CloudPickleTest(unittest.TestCase): + def test_itemgetter(self): + from operator import itemgetter + ser = CloudPickleSerializer() + d = range(10) + getter = itemgetter(1) + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + + getter = itemgetter(0, 3) + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + + def test_attrgetter(self): + from operator import attrgetter + ser = CloudPickleSerializer() + + class C(object): + def __getattr__(self, item): + return item + d = C() + getter = attrgetter("a") + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + getter = attrgetter("a", "b") + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + + d.e = C() + getter = attrgetter("e.a") + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + getter = attrgetter("e.a", "e.b") + getter2 = ser.loads(ser.dumps(getter)) + self.assertEqual(getter(d), getter2(d)) + + # Regression test for SPARK-3415 def test_pickling_file_handles(self): - from pyspark.cloudpickle import dumps - from StringIO import StringIO - from pickle import load + ser = CloudPickleSerializer() out1 = sys.stderr - out2 = load(StringIO(dumps(out1))) + out2 = ser.loads(ser.dumps(out1)) self.assertEquals(out1, out2) @@ -553,6 +587,14 @@ def test_repartitionAndSortWithinPartitions(self): self.assertEquals(partitions[0], [(0, 5), (0, 8), (2, 6)]) self.assertEquals(partitions[1], [(1, 3), (3, 8), (3, 8)]) + def test_distinct(self): + rdd = self.sc.parallelize((1, 2, 3)*10, 10) + self.assertEquals(rdd.getNumPartitions(), 10) + self.assertEquals(rdd.distinct().count(), 3) + result = rdd.distinct(5) + self.assertEquals(result.getNumPartitions(), 5) + self.assertEquals(result.count(), 3) + class TestSQL(PySparkTestCase): @@ -574,6 +616,43 @@ def test_broadcast_in_udf(self): [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() self.assertEqual("", res[0]) + def test_basic_functions(self): + rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) + srdd = self.sqlCtx.jsonRDD(rdd) + srdd.count() + srdd.collect() + srdd.schemaString() + srdd.schema() + + # cache and checkpoint + self.assertFalse(srdd.is_cached) + srdd.persist() + srdd.unpersist() + srdd.cache() + self.assertTrue(srdd.is_cached) + self.assertFalse(srdd.isCheckpointed()) + self.assertEqual(None, srdd.getCheckpointFile()) + + srdd = srdd.coalesce(2, True) + srdd = srdd.repartition(3) + srdd = srdd.distinct() + srdd.intersection(srdd) + self.assertEqual(2, srdd.count()) + + srdd.registerTempTable("temp") + srdd = self.sqlCtx.sql("select foo from temp") + srdd.count() + srdd.collect() + + def test_distinct(self): + rdd = self.sc.parallelize(['{"a": 1}', '{"b": 2}', '{"c": 3}']*10, 10) + srdd = self.sqlCtx.jsonRDD(rdd) + self.assertEquals(srdd.getNumPartitions(), 10) + self.assertEquals(srdd.distinct().count(), 3) + 
result = srdd.distinct(5) + self.assertEquals(result.getNumPartitions(), 5) + self.assertEquals(result.count(), 3) + class TestIO(PySparkTestCase): @@ -861,8 +940,40 @@ def test_oldhadoop(self): conf=input_conf).collect()) self.assertEqual(old_dataset, dict_data) - @unittest.skipIf(sys.version_info[:2] <= (2, 6), "Skipped on 2.6 until SPARK-2951 is fixed") def test_newhadoop(self): + basepath = self.tempdir.name + data = [(1, ""), + (1, "a"), + (2, "bcdf")] + self.sc.parallelize(data).saveAsNewAPIHadoopFile( + basepath + "/newhadoop/", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text") + result = sorted(self.sc.newAPIHadoopFile( + basepath + "/newhadoop/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text").collect()) + self.assertEqual(result, data) + + conf = { + "mapreduce.outputformat.class": + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class": "org.apache.hadoop.io.Text", + "mapred.output.dir": basepath + "/newdataset/" + } + self.sc.parallelize(data).saveAsNewAPIHadoopDataset(conf) + input_conf = {"mapred.input.dir": basepath + "/newdataset/"} + new_dataset = sorted(self.sc.newAPIHadoopRDD( + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text", + conf=input_conf).collect()) + self.assertEqual(new_dataset, data) + + def test_newhadoop_with_array(self): basepath = self.tempdir.name # use custom ArrayWritable types and converters to handle arrays array_data = [(1, array('d')), @@ -1127,11 +1238,46 @@ def run(): except OSError: self.fail("daemon had been killed") + # run a normal job + rdd = self.sc.parallelize(range(100), 1) + self.assertEqual(100, rdd.map(str).count()) + def test_fd_leak(self): N = 1100 # fd limit is 1024 by default rdd = self.sc.parallelize(range(N), N) self.assertEquals(N, rdd.count()) + def test_after_exception(self): + def raise_exception(_): + raise Exception() + rdd = self.sc.parallelize(range(100), 1) + self.assertRaises(Exception, lambda: rdd.foreach(raise_exception)) + self.assertEqual(100, rdd.map(str).count()) + + def test_after_jvm_exception(self): + tempFile = tempfile.NamedTemporaryFile(delete=False) + tempFile.write("Hello World!") + tempFile.close() + data = self.sc.textFile(tempFile.name, 1) + filtered_data = data.filter(lambda x: True) + self.assertEqual(1, filtered_data.count()) + os.unlink(tempFile.name) + self.assertRaises(Exception, lambda: filtered_data.count()) + + rdd = self.sc.parallelize(range(100), 1) + self.assertEqual(100, rdd.map(str).count()) + + def test_accumulator_when_reuse_worker(self): + from pyspark.accumulators import INT_ACCUMULATOR_PARAM + acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) + self.sc.parallelize(range(100), 20).foreach(lambda x: acc1.add(x)) + self.assertEqual(sum(range(100)), acc1.value) + + acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) + self.sc.parallelize(range(100), 20).foreach(lambda x: acc2.add(x)) + self.assertEqual(sum(range(100)), acc2.value) + self.assertEqual(sum(range(100)), acc1.value) + class TestSparkSubmit(unittest.TestCase): diff --git a/python/pyspark/traceback_utils.py b/python/pyspark/traceback_utils.py new file mode 100644 index 0000000000000..bb8646df2b0bf --- /dev/null +++ b/python/pyspark/traceback_utils.py @@ -0,0 +1,78 
@@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import namedtuple +import os +import traceback + + +CallSite = namedtuple("CallSite", "function file linenum") + + +def first_spark_call(): + """ + Return a CallSite representing the first Spark call in the current call stack. + """ + tb = traceback.extract_stack() + if len(tb) == 0: + return None + file, line, module, what = tb[len(tb) - 1] + sparkpath = os.path.dirname(file) + first_spark_frame = len(tb) - 1 + for i in range(0, len(tb)): + file, line, fun, what = tb[i] + if file.startswith(sparkpath): + first_spark_frame = i + break + if first_spark_frame == 0: + file, line, fun, what = tb[0] + return CallSite(function=fun, file=file, linenum=line) + sfile, sline, sfun, swhat = tb[first_spark_frame] + ufile, uline, ufun, uwhat = tb[first_spark_frame - 1] + return CallSite(function=sfun, file=ufile, linenum=uline) + + +class SCCallSiteSync(object): + """ + Helper for setting the spark context call site. + + Example usage: + from pyspark.context import SCCallSiteSync + with SCCallSiteSync() as css: + + """ + + _spark_stack_depth = 0 + + def __init__(self, sc): + call_site = first_spark_call() + if call_site is not None: + self._call_site = "%s at %s:%s" % ( + call_site.function, call_site.file, call_site.linenum) + else: + self._call_site = "Error! Could not extract traceback info" + self._context = sc + + def __enter__(self): + if SCCallSiteSync._spark_stack_depth == 0: + self._context._jsc.setCallSite(self._call_site) + SCCallSiteSync._spark_stack_depth += 1 + + def __exit__(self, type, value, tb): + SCCallSiteSync._spark_stack_depth -= 1 + if SCCallSiteSync._spark_stack_depth == 0: + self._context._jsc.setCallSite(None) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 6805063e06798..252176ac65fec 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,16 +23,14 @@ import time import socket import traceback -# CloudPickler needs to be imported so that depicklers are registered using the -# copy_reg module. 
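A small sketch of how the `SCCallSiteSync` helper introduced above is meant to be used from action implementations; `collect_like_action`, `sc`, and `jrdd` are hypothetical names for illustration.

    from pyspark.traceback_utils import SCCallSiteSync, first_spark_call

    # first_spark_call() pairs the first PySpark frame on the stack with the
    # user frame that called into it.
    print(first_spark_call())

    def collect_like_action(sc, jrdd):
        # Hypothetical action wrapper: only the outermost action sets and
        # clears the JVM call site, via the class-level _spark_stack_depth.
        with SCCallSiteSync(sc) as css:
            return list(jrdd.collect())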
+ from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry -from pyspark.cloudpickle import CloudPickler from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, write_int, read_long, \ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ CompressedSerializer - +from pyspark import shuffle pickleSer = PickleSerializer() utf8_deserializer = UTF8Deserializer() @@ -52,6 +50,11 @@ def main(infile, outfile): if split_index == -1: # for unit tests return + # initialize global state + shuffle.MemoryBytesSpilled = 0 + shuffle.DiskBytesSpilled = 0 + _accumulatorRegistry.clear() + # fetch name of workdir spark_files_dir = utf8_deserializer.loads(infile) SparkFiles._root_directory = spark_files_dir @@ -69,9 +72,14 @@ def main(infile, outfile): ser = CompressedSerializer(pickleSer) for _ in range(num_broadcast_variables): bid = read_long(infile) - value = ser._read_with_length(infile) - _broadcastRegistry[bid] = Broadcast(bid, value) + if bid >= 0: + value = ser._read_with_length(infile) + _broadcastRegistry[bid] = Broadcast(bid, value) + else: + bid = - bid - 1 + _broadcastRegistry.remove(bid) + _accumulatorRegistry.clear() command = pickleSer._read_with_length(infile) (func, deserializer, serializer) = command init_time = time.time() @@ -92,6 +100,9 @@ def main(infile, outfile): exit(-1) finish_time = time.time() report_times(outfile, boot_time, init_time, finish_time) + write_long(shuffle.MemoryBytesSpilled, outfile) + write_long(shuffle.DiskBytesSpilled, outfile) + # Mark the beginning of the accumulators section of the output write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) write_int(len(_accumulatorRegistry), outfile) diff --git a/python/run-tests b/python/run-tests index d98840de59d2c..a67e5a99fbdcc 100755 --- a/python/run-tests +++ b/python/run-tests @@ -85,6 +85,27 @@ run_test "pyspark/mllib/tests.py" run_test "pyspark/mllib/tree.py" run_test "pyspark/mllib/util.py" +# Try to test with PyPy +if [ $(which pypy) ]; then + export PYSPARK_PYTHON="pypy" + echo "Testing with PyPy version:" + $PYSPARK_PYTHON --version + + run_test "pyspark/rdd.py" + run_test "pyspark/context.py" + run_test "pyspark/conf.py" + run_test "pyspark/sql.py" + # These tests are included in the module-level docs, and so must + # be handled on a higher level rather than within the python file. + export PYSPARK_DOC_TEST=1 + run_test "pyspark/broadcast.py" + run_test "pyspark/accumulators.py" + run_test "pyspark/serializers.py" + unset PYSPARK_DOC_TEST + run_test "pyspark/shuffle.py" + run_test "pyspark/tests.py" +fi + if [[ $FAILED == 0 ]]; then echo -en "\033[32m" # Green echo "Tests passed." 
diff --git a/repl/pom.xml b/repl/pom.xml index fcc5f90d870e8..af528c8914335 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -99,6 +99,20 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + org.apache.maven.plugins + maven-install-plugin + + true + + org.scalatest scalatest-maven-plugin diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala index d9eeffa86016a..e56b74edba88c 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -15,15 +15,15 @@ import scala.tools.nsc._ import scala.tools.nsc.backend.JavaPlatform import scala.tools.nsc.interpreter._ -import scala.tools.nsc.interpreter.{ Results => IR } -import Predef.{ println => _, _ } -import java.io.{ BufferedReader, FileReader } +import scala.tools.nsc.interpreter.{Results => IR} +import Predef.{println => _, _} +import java.io.{BufferedReader, FileReader} import java.net.URI import java.util.concurrent.locks.ReentrantLock import scala.sys.process.Process import scala.tools.nsc.interpreter.session._ -import scala.util.Properties.{ jdkHome, javaVersion } -import scala.tools.util.{ Javap } +import scala.util.Properties.{jdkHome, javaVersion} +import scala.tools.util.{Javap} import scala.annotation.tailrec import scala.collection.mutable.ListBuffer import scala.concurrent.ops @@ -33,7 +33,7 @@ import scala.tools.nsc.io.{File, Directory} import scala.reflect.NameTransformer._ import scala.tools.nsc.util.ScalaClassLoader._ import scala.tools.util._ -import scala.language.{implicitConversions, existentials} +import scala.language.{implicitConversions, existentials, postfixOps} import scala.reflect.{ClassTag, classTag} import scala.tools.reflect.StdRuntimeTags._ diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 76ba1ecca33ab..c54f8b72ebf42 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -140,5 +140,6 @@ + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 4a9524074132e..574d96d92942b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -40,7 +40,12 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool // TODO: pass this in as a parameter. val fixedPoint = FixedPoint(100) - val batches: Seq[Batch] = Seq( + /** + * Override to provide additional rules for the "Resolution" batch. + */ + val extendedRules: Seq[Rule[LogicalPlan]] = Nil + + lazy val batches: Seq[Batch] = Seq( Batch("MultiInstanceRelations", Once, NewRelationInstances), Batch("CaseInsensitiveAttributeReferences", Once, @@ -54,8 +59,9 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool StarExpansion :: ResolveFunctions :: GlobalAggregates :: - UnresolvedHavingClauseAttributes :: - typeCoercionRules :_*), + UnresolvedHavingClauseAttributes :: + typeCoercionRules ++ + extendedRules : _*), Batch("Check Analysis", Once, CheckResolution), Batch("AnalysisOperators", fixedPoint, @@ -63,7 +69,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool ) /** - * Makes sure all attributes have been resolved. 
+ * Makes sure all attributes and logical plans have been resolved. */ object CheckResolution extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { @@ -71,6 +77,13 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool case p if p.expressions.exists(!_.resolved) => throw new TreeNodeException(p, s"Unresolved attributes: ${p.expressions.filterNot(_.resolved).mkString(",")}") + case p if !p.resolved && p.childrenResolved => + throw new TreeNodeException(p, "Unresolved plan found") + } match { + // As a backstop, use the root node to check that the entire plan tree is resolved. + case p if !p.resolved => + throw new TreeNodeException(p, "Unresolved plan in tree") + case p => p } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index bd8131c9af6e0..79e5283e86a37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -286,6 +286,10 @@ trait HiveTypeCoercion { // If the data type is not boolean and is being cast boolean, turn it into a comparison // with the numeric value, i.e. x != 0. This will coerce the type into numeric type. case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0))) + // Stringify boolean if casting to StringType. + // TODO Ensure true/false string letter casing is consistent with Hive in all cases. + case Cast(e, StringType) if e.dataType == BooleanType => + If(e, Literal("true"), Literal("false")) // Turn true into 1, and false into 0 if casting boolean into other types. case Cast(e, dataType) if e.dataType == BooleanType => Cast(If(e, Literal(1), Literal(0)), dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala index 088f11ee4aa53..9cbab3d5d0d0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala @@ -171,7 +171,7 @@ final class MutableByte extends MutableValue { } final class MutableAny extends MutableValue { - var value: Any = 0 + var value: Any = _ def boxed = if (isNull) null else value def update(v: Any) = value = { isNull = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ddd4b3755d629..a4133feae8166 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -40,12 +40,60 @@ object Optimizer extends RuleExecutor[LogicalPlan] { SimplifyCasts, SimplifyCaseConversionExpressions) :: Batch("Filter Pushdown", FixedPoint(100), + UnionPushdown, CombineFilters, PushPredicateThroughProject, PushPredicateThroughJoin, ColumnPruning) :: Nil } +/** + * Pushes operations to either side of a Union. + */ +object UnionPushdown extends Rule[LogicalPlan] { + + /** + * Maps Attributes from the left side to the corresponding Attribute on the right side. 
+ */ + def buildRewrites(union: Union): AttributeMap[Attribute] = { + assert(union.left.output.size == union.right.output.size) + + AttributeMap(union.left.output.zip(union.right.output)) + } + + /** + * Rewrites an expression so that it can be pushed to the right side of a Union operator. + * This method relies on the fact that the output attributes of a union are always equal + * to the left child's output. + */ + def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]): A = { + val result = e transform { + case a: Attribute => rewrites(a) + } + + // We must promise the compiler that we did not discard the names in the case of project + // expressions. This is safe since the only transformation is from Attribute => Attribute. + result.asInstanceOf[A] + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // Push down filter into union + case Filter(condition, u @ Union(left, right)) => + val rewrites = buildRewrites(u) + Union( + Filter(condition, left), + Filter(pushToRight(condition, rewrites), right)) + + // Push down projection into union + case Project(projectList, u @ Union(left, right)) => + val rewrites = buildRewrites(u) + Union( + Project(projectList, left), + Project(projectList.map(pushToRight(_, rewrites)), right)) + } +} + + /** * Attempts to eliminate the reading of unneeded columns from the query plan using the following * transformations: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index bae491f07c13f..ede431ad4ab27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -58,7 +58,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { /** * Returns true if this expression and all its children have been resolved to a specific schema - * and false if it is still contains any unresolved placeholders. Implementations of LogicalPlan + * and false if it still contains any unresolved placeholders. Implementations of LogicalPlan * can override this (e.g. * [[org.apache.spark.sql.catalyst.analysis.UnresolvedRelation UnresolvedRelation]] * should return `false`). 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 4adfb189372d6..5d10754c7b028 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -114,11 +114,12 @@ case class InsertIntoTable( } } -case class InsertIntoCreatedTable( +case class CreateTableAsSelect( databaseName: Option[String], tableName: String, child: LogicalPlan) extends UnaryNode { override def output = child.output + override lazy val resolved = (databaseName != None && childrenResolved) } case class WriteToFile( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 70c6d06cf2534..49520b7678e90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -308,13 +308,9 @@ case class StructField(name: String, dataType: DataType, nullable: Boolean) { object StructType { protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable))) - - private def validateFields(fields: Seq[StructField]): Boolean = - fields.map(field => field.name).distinct.size == fields.size } case class StructType(fields: Seq[StructField]) extends DataType { - require(StructType.validateFields(fields), "Found fields with the same name.") /** * Returns all field names in a [[Seq]]. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 0a4fde3de7752..5809a108ff62e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -93,6 +93,17 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { val e = intercept[TreeNodeException[_]] { caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("abcd")), testRelation)) } - assert(e.getMessage().toLowerCase.contains("unresolved")) + assert(e.getMessage().toLowerCase.contains("unresolved attribute")) + } + + test("throw errors for unresolved plans during analysis") { + case class UnresolvedTestPlan() extends LeafNode { + override lazy val resolved = false + override def output = Nil + } + val e = intercept[TreeNodeException[_]] { + caseSensitiveAnalyze(UnresolvedTestPlan()) + } + assert(e.getMessage().toLowerCase.contains("unresolved plan")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index ba8b853b6f99e..baeb9b0cf5964 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.analysis import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import 
org.apache.spark.sql.catalyst.types._

 class HiveTypeCoercionSuite extends FunSuite {
@@ -84,4 +86,17 @@ class HiveTypeCoercionSuite extends FunSuite {
     widenTest(StringType, MapType(IntegerType, StringType, true), None)
     widenTest(ArrayType(IntegerType), StructType(Seq()), None)
   }
+
+  test("boolean casts") {
+    val booleanCasts = new HiveTypeCoercion { }.BooleanCasts
+    def ruleTest(initial: Expression, transformed: Expression) {
+      val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
+      assert(booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)) ==
+        Project(Seq(Alias(transformed, "a")()), testRelation))
+    }
+    // Remove superfluous boolean -> boolean casts.
+    ruleTest(Cast(Literal(true), BooleanType), Literal(true))
+    // Stringify boolean when casting to string.
+    ruleTest(Cast(Literal(false), StringType), If(Literal(false), Literal("true"), Literal("false")))
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala
new file mode 100644
index 0000000000000..dfef87bd9133d
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class UnionPushdownSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateAnalysisOperators) :: + Batch("Union Pushdown", Once, + UnionPushdown) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + val testUnion = Union(testRelation, testRelation2) + + test("union: filter to each side") { + val query = testUnion.where('a === 1) + + val optimized = Optimize(query.analyze) + + val correctAnswer = + Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("union: project to each side") { + val query = testUnion.select('b) + + val optimized = Optimize(query.analyze) + + val correctAnswer = + Union(testRelation.select('b), testRelation2.select('e)).analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a2f334aab9fdf..7dbaf7faff0c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -414,7 +414,7 @@ class SQLContext(@transient val sparkContext: SparkContext) def simpleString: String = s"""== Physical Plan == |${stringOrError(executedPlan)} - """ + """.stripMargin.trim override def toString: String = // TODO previously will output RDD details by run (${stringOrError(toRdd.toDebugString)}) @@ -460,7 +460,6 @@ class SQLContext(@transient val sparkContext: SparkContext) rdd: RDD[Array[Any]], schema: StructType): SchemaRDD = { import scala.collection.JavaConversions._ - import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} def needsConversion(dataType: DataType): Boolean = dataType match { case ByteType => true @@ -482,8 +481,7 @@ class SQLContext(@transient val sparkContext: SparkContext) case (null, _) => null case (c: java.util.List[_], ArrayType(elementType, _)) => - val converted = c.map { e => convert(e, elementType)} - JListWrapper(converted) + c.map { e => convert(e, elementType)}: Seq[Any] case (c, ArrayType(elementType, _)) if c.getClass.isArray => c.asInstanceOf[Array[_]].map(e => convert(e, elementType)): Seq[Any] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index d2ceb4a2b0b25..3bc5dce095511 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -377,15 +377,15 @@ class SchemaRDD( def toJavaSchemaRDD: JavaSchemaRDD = new JavaSchemaRDD(sqlContext, logicalPlan) /** - * Converts a JavaRDD to a PythonRDD. It is used by pyspark. + * Helper for converting a Row to a simple Array suitable for pyspark serialization. 
*/ - private[sql] def javaToPython: JavaRDD[Array[Byte]] = { + private def rowToJArray(row: Row, structType: StructType): Array[Any] = { import scala.collection.Map def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { case (null, _) => null - case (obj: Row, struct: StructType) => rowToArray(obj, struct) + case (obj: Row, struct: StructType) => rowToJArray(obj, struct) case (seq: Seq[Any], array: ArrayType) => seq.map(x => toJava(x, array.elementType)).asJava @@ -402,22 +402,37 @@ class SchemaRDD( case (other, _) => other } - def rowToArray(row: Row, structType: StructType): Array[Any] = { - val fields = structType.fields.map(field => field.dataType) - row.zip(fields).map { - case (obj, dataType) => toJava(obj, dataType) - }.toArray - } + val fields = structType.fields.map(field => field.dataType) + row.zip(fields).map { + case (obj, dataType) => toJava(obj, dataType) + }.toArray + } + /** + * Converts a JavaRDD to a PythonRDD. It is used by pyspark. + */ + private[sql] def javaToPython: JavaRDD[Array[Byte]] = { val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output) this.mapPartitions { iter => val pickle = new Pickler iter.map { row => - rowToArray(row, rowSchema) + rowToJArray(row, rowSchema) }.grouped(100).map(batched => pickle.dumps(batched.toArray)) } } + /** + * Serializes the Array[Row] returned by SchemaRDD's optimized collect(), using the same + * format as javaToPython. It is used by pyspark. + */ + private[sql] def collectToPython: JList[Array[Byte]] = { + val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output) + val pickle = new Pickler + new java.util.ArrayList(collect().map { row => + rowToJArray(row, rowSchema) + }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable) + } + /** * Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value * of base RDD functions that do not change schema. @@ -433,7 +448,7 @@ class SchemaRDD( } // ======================================================================= - // Overriden RDD actions + // Overridden RDD actions // ======================================================================= override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index 2f3033a5f94f0..e52eeb3e1c47e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -54,7 +54,7 @@ private[sql] trait SchemaRDDLike { @transient protected[spark] val logicalPlan: LogicalPlan = baseLogicalPlan match { // For various commands (like DDL) and queries with side effects, we force query optimization to // happen right away to let these side effects take place eagerly. - case _: Command | _: InsertIntoTable | _: InsertIntoCreatedTable | _: WriteToFile => + case _: Command | _: InsertIntoTable | _: CreateTableAsSelect |_: WriteToFile => queryExecution.toRdd SparkLogicalPlan(queryExecution.executedPlan)(sqlContext) case _ => @@ -124,7 +124,7 @@ private[sql] trait SchemaRDDLike { */ @Experimental def saveAsTable(tableName: String): Unit = - sqlContext.executePlan(InsertIntoCreatedTable(None, tableName, logicalPlan)).toRdd + sqlContext.executePlan(CreateTableAsSelect(None, tableName, logicalPlan)).toRdd /** Returns the schema as a string in the tree format. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index 4d799b4038fdd..e7faba0c7f620 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -112,6 +112,8 @@ class JavaSchemaRDD( new java.util.ArrayList(arr) } + override def count(): Long = baseSchemaRDD.count + override def take(num: Int): JList[Row] = { import scala.collection.JavaConversions._ val arr: java.util.Collection[Row] = baseSchemaRDD.take(num).toSeq.map(new Row(_)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index 42a5a9a84f362..c9faf0852142a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -50,11 +50,13 @@ private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType]( def hasNext = buffer.hasRemaining - def extractTo(row: MutableRow, ordinal: Int) { - columnType.setField(row, ordinal, extractSingle(buffer)) + def extractTo(row: MutableRow, ordinal: Int): Unit = { + extractSingle(row, ordinal) } - def extractSingle(buffer: ByteBuffer): JvmType = columnType.extract(buffer) + def extractSingle(row: MutableRow, ordinal: Int): Unit = { + columnType.extract(buffer, row, ordinal) + } protected def underlyingBuffer = buffer } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index b3ec5ded22422..2e61a981375aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -68,10 +68,9 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( buffer.order(ByteOrder.nativeOrder()).putInt(columnType.typeId) } - override def appendFrom(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - buffer = ensureFreeSpace(buffer, columnType.actualSize(field)) - columnType.append(field, buffer) + override def appendFrom(row: Row, ordinal: Int): Unit = { + buffer = ensureFreeSpace(buffer, columnType.actualSize(row, ordinal)) + columnType.append(row, ordinal, buffer) } override def build() = { @@ -142,16 +141,16 @@ private[sql] object ColumnBuilder { useCompression: Boolean = false): ColumnBuilder = { val builder = (typeId match { - case INT.typeId => new IntColumnBuilder - case LONG.typeId => new LongColumnBuilder - case FLOAT.typeId => new FloatColumnBuilder - case DOUBLE.typeId => new DoubleColumnBuilder - case BOOLEAN.typeId => new BooleanColumnBuilder - case BYTE.typeId => new ByteColumnBuilder - case SHORT.typeId => new ShortColumnBuilder - case STRING.typeId => new StringColumnBuilder - case BINARY.typeId => new BinaryColumnBuilder - case GENERIC.typeId => new GenericColumnBuilder + case INT.typeId => new IntColumnBuilder + case LONG.typeId => new LongColumnBuilder + case FLOAT.typeId => new FloatColumnBuilder + case DOUBLE.typeId => new DoubleColumnBuilder + case BOOLEAN.typeId => new BooleanColumnBuilder + case BYTE.typeId => new ByteColumnBuilder + case SHORT.typeId => new ShortColumnBuilder + case STRING.typeId => new StringColumnBuilder + case BINARY.typeId => new BinaryColumnBuilder + case GENERIC.typeId => new 
GenericColumnBuilder case TIMESTAMP.typeId => new TimestampColumnBuilder }).asInstanceOf[ColumnBuilder] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index fc343ccb995c2..203a714e03c97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -69,7 +69,7 @@ private[sql] class ByteColumnStats extends ColumnStats { var lower = Byte.MaxValue var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getByte(ordinal) if (value > upper) upper = value @@ -87,7 +87,7 @@ private[sql] class ShortColumnStats extends ColumnStats { var lower = Short.MaxValue var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getShort(ordinal) if (value > upper) upper = value @@ -105,7 +105,7 @@ private[sql] class LongColumnStats extends ColumnStats { var lower = Long.MaxValue var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getLong(ordinal) if (value > upper) upper = value @@ -123,7 +123,7 @@ private[sql] class DoubleColumnStats extends ColumnStats { var lower = Double.MaxValue var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getDouble(ordinal) if (value > upper) upper = value @@ -141,7 +141,7 @@ private[sql] class FloatColumnStats extends ColumnStats { var lower = Float.MaxValue var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getFloat(ordinal) if (value > upper) upper = value @@ -159,7 +159,7 @@ private[sql] class IntColumnStats extends ColumnStats { var lower = Int.MaxValue var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getInt(ordinal) if (value > upper) upper = value @@ -177,7 +177,7 @@ private[sql] class StringColumnStats extends ColumnStats { var lower: String = null var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row.getString(ordinal) if (upper == null || value.compareTo(upper) > 0) upper = value @@ -195,7 +195,7 @@ private[sql] class TimestampColumnStats extends ColumnStats { var lower: Timestamp = null var nullCount = 0 - override def gatherStats(row: Row, ordinal: Int) { + override def gatherStats(row: Row, ordinal: Int): Unit = { if (!row.isNullAt(ordinal)) { val value = row(ordinal).asInstanceOf[Timestamp] if (upper == null || value.compareTo(upper) > 0) upper = value diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 9a61600115872..198b5756676aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala 
@@ -18,11 +18,10 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer +import java.sql.Timestamp import scala.reflect.runtime.universe.TypeTag -import java.sql.Timestamp - import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.catalyst.types._ @@ -46,16 +45,33 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( */ def extract(buffer: ByteBuffer): JvmType + /** + * Extracts a value out of the buffer at the buffer's current position and stores in + * `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs whenever + * possible. + */ + def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + setField(row, ordinal, extract(buffer)) + } + /** * Appends the given value v of type T into the given ByteBuffer. */ - def append(v: JvmType, buffer: ByteBuffer) + def append(v: JvmType, buffer: ByteBuffer): Unit + + /** + * Appends `row(ordinal)` of type T into the given ByteBuffer. Subclasses should override this + * method to avoid boxing/unboxing costs whenever possible. + */ + def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + append(getField(row, ordinal), buffer) + } /** - * Returns the size of the value. This is used to calculate the size of variable length types - * such as byte arrays and strings. + * Returns the size of the value `row(ordinal)`. This is used to calculate the size of variable + * length types such as byte arrays and strings. */ - def actualSize(v: JvmType): Int = defaultSize + def actualSize(row: Row, ordinal: Int): Int = defaultSize /** * Returns `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs @@ -67,7 +83,15 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( * Sets `row(ordinal)` to `field`. Subclasses should override this method to avoid boxing/unboxing * costs whenever possible. */ - def setField(row: MutableRow, ordinal: Int, value: JvmType) + def setField(row: MutableRow, ordinal: Int, value: JvmType): Unit + + /** + * Copies `from(fromOrdinal)` to `to(toOrdinal)`. Subclasses should override this method to avoid + * boxing/unboxing costs whenever possible. + */ + def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to(toOrdinal) = from(fromOrdinal) + } /** * Creates a duplicated copy of the value. 
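To make the intent of these new row-based overloads concrete, here is a minimal usage sketch. It is an illustration only: the rows, the buffer, and the assumption that it runs inside the `org.apache.spark.sql.columnar` package (so the `private[sql]` members resolve) are the editor's, not code taken from this change.

    import java.nio.{ByteBuffer, ByteOrder}
    import org.apache.spark.sql.catalyst.expressions.GenericMutableRow

    val srcRow = new GenericMutableRow(2)
    srcRow.setInt(0, 42)
    val destRow = new GenericMutableRow(2)

    val buffer = ByteBuffer.allocate(INT.defaultSize).order(ByteOrder.nativeOrder())
    INT.append(srcRow, 0, buffer)         // writes srcRow.getInt(0) straight into the buffer
    buffer.rewind()
    INT.extract(buffer, destRow, 0)       // reads the int back directly into destRow
    INT.copyField(srcRow, 0, destRow, 1)  // copies the int field from srcRow(0) to destRow(1)

The concrete column types below implement exactly this pattern for each primitive, which is how the intermediate `JvmType` value (and its boxing) is avoided.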
@@ -90,119 +114,205 @@ private[sql] abstract class NativeColumnType[T <: NativeType]( } private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { - def append(v: Int, buffer: ByteBuffer) { + def append(v: Int, buffer: ByteBuffer): Unit = { buffer.putInt(v) } + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putInt(row.getInt(ordinal)) + } + def extract(buffer: ByteBuffer) = { buffer.getInt() } - override def setField(row: MutableRow, ordinal: Int, value: Int) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setInt(ordinal, buffer.getInt()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Int): Unit = { row.setInt(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getInt(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setInt(toOrdinal, from.getInt(fromOrdinal)) + } } private[sql] object LONG extends NativeColumnType(LongType, 1, 8) { - override def append(v: Long, buffer: ByteBuffer) { + override def append(v: Long, buffer: ByteBuffer): Unit = { buffer.putLong(v) } + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putLong(row.getLong(ordinal)) + } + override def extract(buffer: ByteBuffer) = { buffer.getLong() } - override def setField(row: MutableRow, ordinal: Int, value: Long) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setLong(ordinal, buffer.getLong()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Long): Unit = { row.setLong(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getLong(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setLong(toOrdinal, from.getLong(fromOrdinal)) + } } private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) { - override def append(v: Float, buffer: ByteBuffer) { + override def append(v: Float, buffer: ByteBuffer): Unit = { buffer.putFloat(v) } + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putFloat(row.getFloat(ordinal)) + } + override def extract(buffer: ByteBuffer) = { buffer.getFloat() } - override def setField(row: MutableRow, ordinal: Int, value: Float) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setFloat(ordinal, buffer.getFloat()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Float): Unit = { row.setFloat(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getFloat(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setFloat(toOrdinal, from.getFloat(fromOrdinal)) + } } private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) { - override def append(v: Double, buffer: ByteBuffer) { + override def append(v: Double, buffer: ByteBuffer): Unit = { buffer.putDouble(v) } + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putDouble(row.getDouble(ordinal)) + } + override def extract(buffer: ByteBuffer) = { buffer.getDouble() } - override def setField(row: MutableRow, ordinal: Int, value: Double) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setDouble(ordinal, buffer.getDouble()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Double): Unit = { row.setDouble(ordinal, value) } 
override def getField(row: Row, ordinal: Int) = row.getDouble(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setDouble(toOrdinal, from.getDouble(fromOrdinal)) + } } private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) { - override def append(v: Boolean, buffer: ByteBuffer) { - buffer.put(if (v) 1.toByte else 0.toByte) + override def append(v: Boolean, buffer: ByteBuffer): Unit = { + buffer.put(if (v) 1: Byte else 0: Byte) + } + + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.put(if (row.getBoolean(ordinal)) 1: Byte else 0: Byte) } override def extract(buffer: ByteBuffer) = buffer.get() == 1 - override def setField(row: MutableRow, ordinal: Int, value: Boolean) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setBoolean(ordinal, buffer.get() == 1) + } + + override def setField(row: MutableRow, ordinal: Int, value: Boolean): Unit = { row.setBoolean(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getBoolean(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setBoolean(toOrdinal, from.getBoolean(fromOrdinal)) + } } private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) { - override def append(v: Byte, buffer: ByteBuffer) { + override def append(v: Byte, buffer: ByteBuffer): Unit = { buffer.put(v) } + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.put(row.getByte(ordinal)) + } + override def extract(buffer: ByteBuffer) = { buffer.get() } - override def setField(row: MutableRow, ordinal: Int, value: Byte) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setByte(ordinal, buffer.get()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Byte): Unit = { row.setByte(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getByte(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setByte(toOrdinal, from.getByte(fromOrdinal)) + } } private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { - override def append(v: Short, buffer: ByteBuffer) { + override def append(v: Short, buffer: ByteBuffer): Unit = { buffer.putShort(v) } + override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + buffer.putShort(row.getShort(ordinal)) + } + override def extract(buffer: ByteBuffer) = { buffer.getShort() } - override def setField(row: MutableRow, ordinal: Int, value: Short) { + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + row.setShort(ordinal, buffer.getShort()) + } + + override def setField(row: MutableRow, ordinal: Int, value: Short): Unit = { row.setShort(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getShort(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setShort(toOrdinal, from.getShort(fromOrdinal)) + } } private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { - override def actualSize(v: String): Int = v.getBytes("utf-8").length + 4 + override def actualSize(row: Row, ordinal: Int): Int = { + row.getString(ordinal).getBytes("utf-8").length + 4 + } - override def append(v: String, buffer: ByteBuffer) { + override def append(v: String, buffer: ByteBuffer): Unit = { val stringBytes = v.getBytes("utf-8") 
buffer.putInt(stringBytes.length).put(stringBytes, 0, stringBytes.length) } @@ -214,11 +324,15 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { new String(stringBytes, "utf-8") } - override def setField(row: MutableRow, ordinal: Int, value: String) { + override def setField(row: MutableRow, ordinal: Int, value: String): Unit = { row.setString(ordinal, value) } override def getField(row: Row, ordinal: Int) = row.getString(ordinal) + + override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + to.setString(toOrdinal, from.getString(fromOrdinal)) + } } private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 8, 12) { @@ -228,7 +342,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 8, 12) { timestamp } - override def append(v: Timestamp, buffer: ByteBuffer) { + override def append(v: Timestamp, buffer: ByteBuffer): Unit = { buffer.putLong(v.getTime).putInt(v.getNanos) } @@ -236,7 +350,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 8, 12) { row(ordinal).asInstanceOf[Timestamp] } - override def setField(row: MutableRow, ordinal: Int, value: Timestamp) { + override def setField(row: MutableRow, ordinal: Int, value: Timestamp): Unit = { row(ordinal) = value } } @@ -246,9 +360,11 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( defaultSize: Int) extends ColumnType[T, Array[Byte]](typeId, defaultSize) { - override def actualSize(v: Array[Byte]) = v.length + 4 + override def actualSize(row: Row, ordinal: Int) = { + getField(row, ordinal).length + 4 + } - override def append(v: Array[Byte], buffer: ByteBuffer) { + override def append(v: Array[Byte], buffer: ByteBuffer): Unit = { buffer.putInt(v.length).put(v, 0, v.length) } @@ -261,7 +377,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( } private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](9, 16) { - override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) { + override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { row(ordinal) = value } @@ -272,7 +388,7 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](9, 16) { // serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized // byte array. private[sql] object GENERIC extends ByteArrayColumnType[DataType](10, 16) { - override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) { + override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { row(ordinal) = SparkSqlSerializer.deserialize[Any](value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 6eab2f23c18e1..8a3612cdf19be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -52,7 +52,7 @@ private[sql] case class InMemoryRelation( // As in Spark, the actual work of caching is lazy. 
if (_cachedColumnBuffers == null) { val output = child.output - val cached = child.execute().mapPartitions { baseIterator => + val cached = child.execute().mapPartitions { rowIterator => new Iterator[CachedBatch] { def next() = { val columnBuilders = output.map { attribute => @@ -61,11 +61,9 @@ private[sql] case class InMemoryRelation( ColumnBuilder(columnType.typeId, initialBufferSize, attribute.name, useCompression) }.toArray - var row: Row = null var rowCount = 0 - - while (baseIterator.hasNext && rowCount < batchSize) { - row = baseIterator.next() + while (rowIterator.hasNext && rowCount < batchSize) { + val row = rowIterator.next() var i = 0 while (i < row.length) { columnBuilders(i).appendFrom(row, i) @@ -80,7 +78,7 @@ private[sql] case class InMemoryRelation( CachedBatch(columnBuilders.map(_.build()), stats) } - def hasNext = baseIterator.hasNext + def hasNext = rowIterator.hasNext } }.cache() @@ -182,6 +180,7 @@ private[sql] case class InMemoryColumnarTableScan( } } + // Accumulators used for testing purposes val readPartitions = sparkContext.accumulator(0) val readBatches = sparkContext.accumulator(0) @@ -191,40 +190,36 @@ private[sql] case class InMemoryColumnarTableScan( readPartitions.setValue(0) readBatches.setValue(0) - relation.cachedColumnBuffers.mapPartitions { iterator => + relation.cachedColumnBuffers.mapPartitions { cachedBatchIterator => val partitionFilter = newPredicate( partitionFilters.reduceOption(And).getOrElse(Literal(true)), relation.partitionStatistics.schema) - // Find the ordinals of the requested columns. If none are requested, use the first. - val requestedColumns = if (attributes.isEmpty) { - Seq(0) + // Find the ordinals and data types of the requested columns. If none are requested, use the + // narrowest (the field with minimum default element size). 
+ val (requestedColumnIndices, requestedColumnDataTypes) = if (attributes.isEmpty) { + val (narrowestOrdinal, narrowestDataType) = + relation.output.zipWithIndex.map { case (a, ordinal) => + ordinal -> a.dataType + } minBy { case (_, dataType) => + ColumnType(dataType).defaultSize + } + Seq(narrowestOrdinal) -> Seq(narrowestDataType) } else { - attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId)) + attributes.map { a => + relation.output.indexWhere(_.exprId == a.exprId) -> a.dataType + }.unzip } - val rows = iterator - // Skip pruned batches - .filter { cachedBatch => - if (inMemoryPartitionPruningEnabled && !partitionFilter(cachedBatch.stats)) { - def statsString = relation.partitionStatistics.schema - .zip(cachedBatch.stats) - .map { case (a, s) => s"${a.name}: $s" } - .mkString(", ") - logInfo(s"Skipping partition based on stats $statsString") - false - } else { - readBatches += 1 - true - } - } - // Build column accessors - .map { cachedBatch => - requestedColumns.map(cachedBatch.buffers(_)).map(ColumnAccessor(_)) - } - // Extract rows via column accessors - .flatMap { columnAccessors => - val nextRow = new GenericMutableRow(columnAccessors.length) + val nextRow = new SpecificMutableRow(requestedColumnDataTypes) + + def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]) = { + val rows = cacheBatches.flatMap { cachedBatch => + // Build column accessors + val columnAccessors = + requestedColumnIndices.map(cachedBatch.buffers(_)).map(ColumnAccessor(_)) + + // Extract rows via column accessors new Iterator[Row] { override def next() = { var i = 0 @@ -235,15 +230,38 @@ private[sql] case class InMemoryColumnarTableScan( nextRow } - override def hasNext = columnAccessors.head.hasNext + override def hasNext = columnAccessors(0).hasNext } } - if (rows.hasNext) { - readPartitions += 1 + if (rows.hasNext) { + readPartitions += 1 + } + + rows } - rows + // Do partition batch pruning if enabled + val cachedBatchesToScan = + if (inMemoryPartitionPruningEnabled) { + cachedBatchIterator.filter { cachedBatch => + if (!partitionFilter(cachedBatch.stats)) { + def statsString = relation.partitionStatistics.schema + .zip(cachedBatch.stats) + .map { case (a, s) => s"${a.name}: $s" } + .mkString(", ") + logInfo(s"Skipping partition based on stats $statsString") + false + } else { + readBatches += 1 + true + } + } + } else { + cachedBatchIterator + } + + cachedBatchesToRows(cachedBatchesToScan) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala index b7f8826861a2c..965782a40031b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala @@ -29,7 +29,7 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor { private var nextNullIndex: Int = _ private var pos: Int = 0 - abstract override protected def initialize() { + abstract override protected def initialize(): Unit = { nullsBuffer = underlyingBuffer.duplicate().order(ByteOrder.nativeOrder()) nullCount = nullsBuffer.getInt() nextNullIndex = if (nullCount > 0) nullsBuffer.getInt() else -1 @@ -39,7 +39,7 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor { super.initialize() } - abstract override def extractTo(row: MutableRow, ordinal: Int) { + abstract override def extractTo(row: MutableRow, ordinal: Int): Unit = { if (pos == nextNullIndex) { 
seenNulls += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala index a72970eef7aa4..f1f494ac26d0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala @@ -40,7 +40,11 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder { protected var nullCount: Int = _ private var pos: Int = _ - abstract override def initialize(initialSize: Int, columnName: String, useCompression: Boolean) { + abstract override def initialize( + initialSize: Int, + columnName: String, + useCompression: Boolean): Unit = { + nulls = ByteBuffer.allocate(1024) nulls.order(ByteOrder.nativeOrder()) pos = 0 @@ -48,7 +52,7 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder { super.initialize(initialSize, columnName, useCompression) } - abstract override def appendFrom(row: Row, ordinal: Int) { + abstract override def appendFrom(row: Row, ordinal: Int): Unit = { columnStats.gatherStats(row, ordinal) if (row.isNullAt(ordinal)) { nulls = ColumnBuilder.ensureFreeSpace(nulls, 4) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala index b4120a3d4368b..27ac5f4dbdbbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.columnar.compression -import java.nio.ByteBuffer - +import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar.{ColumnAccessor, NativeColumnAccessor} @@ -34,5 +33,7 @@ private[sql] trait CompressibleColumnAccessor[T <: NativeType] extends ColumnAcc abstract override def hasNext = super.hasNext || decoder.hasNext - override def extractSingle(buffer: ByteBuffer): T#JvmType = decoder.next() + override def extractSingle(row: MutableRow, ordinal: Int): Unit = { + decoder.next(row, ordinal) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala index a5826bb033e41..628d9cec41d6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala @@ -48,12 +48,16 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] var compressionEncoders: Seq[Encoder[T]] = _ - abstract override def initialize(initialSize: Int, columnName: String, useCompression: Boolean) { + abstract override def initialize( + initialSize: Int, + columnName: String, + useCompression: Boolean): Unit = { + compressionEncoders = if (useCompression) { - schemes.filter(_.supports(columnType)).map(_.encoder[T]) + schemes.filter(_.supports(columnType)).map(_.encoder[T](columnType)) } else { - Seq(PassThrough.encoder) + Seq(PassThrough.encoder(columnType)) } super.initialize(initialSize, columnName, useCompression) } @@ -62,17 +66,15 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] 
encoder.compressionRatio < 0.8 } - private def gatherCompressibilityStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - + private def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { var i = 0 while (i < compressionEncoders.length) { - compressionEncoders(i).gatherCompressibilityStats(field, columnType) + compressionEncoders(i).gatherCompressibilityStats(row, ordinal) i += 1 } } - abstract override def appendFrom(row: Row, ordinal: Int) { + abstract override def appendFrom(row: Row, ordinal: Int): Unit = { super.appendFrom(row, ordinal) if (!row.isNullAt(ordinal)) { gatherCompressibilityStats(row, ordinal) @@ -84,7 +86,7 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] val typeId = nonNullBuffer.getInt() val encoder: Encoder[T] = { val candidate = compressionEncoders.minBy(_.compressionRatio) - if (isWorthCompressing(candidate)) candidate else PassThrough.encoder + if (isWorthCompressing(candidate)) candidate else PassThrough.encoder(columnType) } // Header = column type ID + null count + null positions @@ -104,7 +106,7 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] .putInt(nullCount) .put(nulls) - logInfo(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}") - encoder.compress(nonNullBuffer, compressedBuffer, columnType) + logDebug(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}") + encoder.compress(nonNullBuffer, compressedBuffer) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala index 7797f75177893..acb06cb5376b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.columnar.compression -import java.nio.{ByteOrder, ByteBuffer} +import java.nio.{ByteBuffer, ByteOrder} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType} private[sql] trait Encoder[T <: NativeType] { - def gatherCompressibilityStats(value: T#JvmType, columnType: NativeColumnType[T]) {} + def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {} def compressedSize: Int @@ -33,17 +35,21 @@ private[sql] trait Encoder[T <: NativeType] { if (uncompressedSize > 0) compressedSize.toDouble / uncompressedSize else 1.0 } - def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]): ByteBuffer + def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer } -private[sql] trait Decoder[T <: NativeType] extends Iterator[T#JvmType] +private[sql] trait Decoder[T <: NativeType] { + def next(row: MutableRow, ordinal: Int): Unit + + def hasNext: Boolean +} private[sql] trait CompressionScheme { def typeId: Int def supports(columnType: ColumnType[_, _]): Boolean - def encoder[T <: NativeType]: Encoder[T] + def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T] def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala index 
8cf9ec74ca2de..29edcf17242c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala @@ -23,7 +23,8 @@ import scala.collection.mutable import scala.reflect.ClassTag import scala.reflect.runtime.universe.runtimeMirror -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar._ import org.apache.spark.util.Utils @@ -33,18 +34,20 @@ private[sql] case object PassThrough extends CompressionScheme { override def supports(columnType: ColumnType[_, _]) = true - override def encoder[T <: NativeType] = new this.Encoder[T] + override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + new this.Encoder[T](columnType) + } override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { new this.Decoder(buffer, columnType) } - class Encoder[T <: NativeType] extends compression.Encoder[T] { + class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { override def uncompressedSize = 0 override def compressedSize = 0 - override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]) = { + override def compress(from: ByteBuffer, to: ByteBuffer) = { // Writes compression type ID and copies raw contents to.putInt(PassThrough.typeId).put(from).rewind() to @@ -54,7 +57,9 @@ private[sql] case object PassThrough extends CompressionScheme { class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) extends compression.Decoder[T] { - override def next() = columnType.extract(buffer) + override def next(row: MutableRow, ordinal: Int): Unit = { + columnType.extract(buffer, row, ordinal) + } override def hasNext = buffer.hasRemaining } @@ -63,7 +68,9 @@ private[sql] case object PassThrough extends CompressionScheme { private[sql] case object RunLengthEncoding extends CompressionScheme { override val typeId = 1 - override def encoder[T <: NativeType] = new this.Encoder[T] + override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + new this.Encoder[T](columnType) + } override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { new this.Decoder(buffer, columnType) @@ -74,24 +81,25 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { case _ => false } - class Encoder[T <: NativeType] extends compression.Encoder[T] { + class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { private var _uncompressedSize = 0 private var _compressedSize = 0 // Using `MutableRow` to store the last value to avoid boxing/unboxing cost. 
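  // Illustrative aside (editor's hedged sketch; the values below are assumed, not
  // taken from this change): a GenericMutableRow stores its fields in an Array[Any],
  // so setInt(0, v) boxes the primitive, while a SpecificMutableRow built from the
  // column's DataType keeps a typed holder (e.g. a MutableInt) and stores the raw value:
  //
  //   val generic  = new GenericMutableRow(1)                  // Array[Any]-backed
  //   val specific = new SpecificMutableRow(Seq(IntegerType))  // typed, primitive storage
  //   specific.setInt(0, 7)
  //   specific.getInt(0)  // 7, read back without unboxing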
- private val lastValue = new GenericMutableRow(1) + private val lastValue = new SpecificMutableRow(Seq(columnType.dataType)) private var lastRun = 0 override def uncompressedSize = _uncompressedSize override def compressedSize = _compressedSize - override def gatherCompressibilityStats(value: T#JvmType, columnType: NativeColumnType[T]) { - val actualSize = columnType.actualSize(value) + override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + val value = columnType.getField(row, ordinal) + val actualSize = columnType.actualSize(row, ordinal) _uncompressedSize += actualSize if (lastValue.isNullAt(0)) { - columnType.setField(lastValue, 0, value) + columnType.copyField(row, ordinal, lastValue, 0) lastRun = 1 _compressedSize += actualSize + 4 } else { @@ -99,37 +107,40 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { lastRun += 1 } else { _compressedSize += actualSize + 4 - columnType.setField(lastValue, 0, value) + columnType.copyField(row, ordinal, lastValue, 0) lastRun = 1 } } } - override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]) = { + override def compress(from: ByteBuffer, to: ByteBuffer) = { to.putInt(RunLengthEncoding.typeId) if (from.hasRemaining) { - var currentValue = columnType.extract(from) + val currentValue = new SpecificMutableRow(Seq(columnType.dataType)) var currentRun = 1 + val value = new SpecificMutableRow(Seq(columnType.dataType)) + + columnType.extract(from, currentValue, 0) while (from.hasRemaining) { - val value = columnType.extract(from) + columnType.extract(from, value, 0) - if (value == currentValue) { + if (value.head == currentValue.head) { currentRun += 1 } else { // Writes current run - columnType.append(currentValue, to) + columnType.append(currentValue, 0, to) to.putInt(currentRun) // Resets current run - currentValue = value + columnType.copyField(value, 0, currentValue, 0) currentRun = 1 } } // Writes the last run - columnType.append(currentValue, to) + columnType.append(currentValue, 0, to) to.putInt(currentRun) } @@ -145,7 +156,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { private var valueCount = 0 private var currentValue: T#JvmType = _ - override def next() = { + override def next(row: MutableRow, ordinal: Int): Unit = { if (valueCount == run) { currentValue = columnType.extract(buffer) run = buffer.getInt() @@ -154,7 +165,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { valueCount += 1 } - currentValue + columnType.setField(row, ordinal, currentValue) } override def hasNext = valueCount < run || buffer.hasRemaining @@ -171,14 +182,16 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { new this.Decoder(buffer, columnType) } - override def encoder[T <: NativeType] = new this.Encoder[T] + override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + new this.Encoder[T](columnType) + } override def supports(columnType: ColumnType[_, _]) = columnType match { case INT | LONG | STRING => true case _ => false } - class Encoder[T <: NativeType] extends compression.Encoder[T] { + class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { // Size of the input, uncompressed, in bytes. Note that we only count until the dictionary // overflows. private var _uncompressedSize = 0 @@ -200,9 +213,11 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { // to store dictionary element count. 
private var dictionarySize = 4 - override def gatherCompressibilityStats(value: T#JvmType, columnType: NativeColumnType[T]) { + override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + val value = columnType.getField(row, ordinal) + if (!overflow) { - val actualSize = columnType.actualSize(value) + val actualSize = columnType.actualSize(row, ordinal) count += 1 _uncompressedSize += actualSize @@ -221,7 +236,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { } } - override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]) = { + override def compress(from: ByteBuffer, to: ByteBuffer) = { if (overflow) { throw new IllegalStateException( "Dictionary encoding should not be used because of dictionary overflow.") @@ -264,7 +279,9 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { } } - override def next() = dictionary(buffer.getShort()) + override def next(row: MutableRow, ordinal: Int): Unit = { + columnType.setField(row, ordinal, dictionary(buffer.getShort())) + } override def hasNext = buffer.hasRemaining } @@ -279,25 +296,20 @@ private[sql] case object BooleanBitSet extends CompressionScheme { new this.Decoder(buffer).asInstanceOf[compression.Decoder[T]] } - override def encoder[T <: NativeType] = (new this.Encoder).asInstanceOf[compression.Encoder[T]] + override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + (new this.Encoder).asInstanceOf[compression.Encoder[T]] + } override def supports(columnType: ColumnType[_, _]) = columnType == BOOLEAN class Encoder extends compression.Encoder[BooleanType.type] { private var _uncompressedSize = 0 - override def gatherCompressibilityStats( - value: Boolean, - columnType: NativeColumnType[BooleanType.type]) { - + override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { _uncompressedSize += BOOLEAN.defaultSize } - override def compress( - from: ByteBuffer, - to: ByteBuffer, - columnType: NativeColumnType[BooleanType.type]) = { - + override def compress(from: ByteBuffer, to: ByteBuffer) = { to.putInt(BooleanBitSet.typeId) // Total element count (1 byte per Boolean value) .putInt(from.remaining) @@ -349,7 +361,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme { private var visited: Int = 0 - override def next(): Boolean = { + override def next(row: MutableRow, ordinal: Int): Unit = { val bit = visited % BITS_PER_LONG visited += 1 @@ -357,123 +369,167 @@ private[sql] case object BooleanBitSet extends CompressionScheme { currentWord = buffer.getLong() } - ((currentWord >> bit) & 1) != 0 + row.setBoolean(ordinal, ((currentWord >> bit) & 1) != 0) } override def hasNext: Boolean = visited < count } } -private[sql] sealed abstract class IntegralDelta[I <: IntegralType] extends CompressionScheme { +private[sql] case object IntDelta extends CompressionScheme { + override def typeId: Int = 4 + override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { - new this.Decoder(buffer, columnType.asInstanceOf[NativeColumnType[I]]) - .asInstanceOf[compression.Decoder[T]] + new Decoder(buffer, INT).asInstanceOf[compression.Decoder[T]] } - override def encoder[T <: NativeType] = (new this.Encoder).asInstanceOf[compression.Encoder[T]] - - /** - * Computes `delta = x - y`, returns `(true, delta)` if `delta` can fit into a single byte, or - * `(false, 0: Byte)` otherwise. 
- */ - protected def byteSizedDelta(x: I#JvmType, y: I#JvmType): (Boolean, Byte) + override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + (new Encoder).asInstanceOf[compression.Encoder[T]] + } - /** - * Simply computes `x + delta` - */ - protected def addDelta(x: I#JvmType, delta: Byte): I#JvmType + override def supports(columnType: ColumnType[_, _]) = columnType == INT - class Encoder extends compression.Encoder[I] { - private var _compressedSize: Int = 0 + class Encoder extends compression.Encoder[IntegerType.type] { + protected var _compressedSize: Int = 0 + protected var _uncompressedSize: Int = 0 - private var _uncompressedSize: Int = 0 + override def compressedSize = _compressedSize + override def uncompressedSize = _uncompressedSize - private var prev: I#JvmType = _ + private var prevValue: Int = _ - private var initial = true + override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + val value = row.getInt(ordinal) + val delta = value - prevValue - override def gatherCompressibilityStats(value: I#JvmType, columnType: NativeColumnType[I]) { - _uncompressedSize += columnType.defaultSize + _compressedSize += 1 - if (initial) { - initial = false - _compressedSize += 1 + columnType.defaultSize - } else { - val (smallEnough, _) = byteSizedDelta(value, prev) - _compressedSize += (if (smallEnough) 1 else 1 + columnType.defaultSize) + // If this is the first integer to be compressed, or the delta is out of byte range, then give + // up compressing this integer. + if (_uncompressedSize == 0 || delta <= Byte.MinValue || delta > Byte.MaxValue) { + _compressedSize += INT.defaultSize } - prev = value + _uncompressedSize += INT.defaultSize + prevValue = value } - override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[I]) = { + override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = { to.putInt(typeId) if (from.hasRemaining) { - var prev = columnType.extract(from) + var prev = from.getInt() to.put(Byte.MinValue) - columnType.append(prev, to) + to.putInt(prev) while (from.hasRemaining) { - val current = columnType.extract(from) - val (smallEnough, delta) = byteSizedDelta(current, prev) + val current = from.getInt() + val delta = current - prev prev = current - if (smallEnough) { - to.put(delta) + if (Byte.MinValue < delta && delta <= Byte.MaxValue) { + to.put(delta.toByte) } else { to.put(Byte.MinValue) - columnType.append(current, to) + to.putInt(current) } } } - to.rewind() - to + to.rewind().asInstanceOf[ByteBuffer] } - - override def uncompressedSize = _uncompressedSize - - override def compressedSize = _compressedSize } - class Decoder(buffer: ByteBuffer, columnType: NativeColumnType[I]) - extends compression.Decoder[I] { + class Decoder(buffer: ByteBuffer, columnType: NativeColumnType[IntegerType.type]) + extends compression.Decoder[IntegerType.type] { + + private var prev: Int = _ - private var prev: I#JvmType = _ + override def hasNext: Boolean = buffer.hasRemaining - override def next() = { + override def next(row: MutableRow, ordinal: Int): Unit = { val delta = buffer.get() - prev = if (delta > Byte.MinValue) addDelta(prev, delta) else columnType.extract(buffer) - prev + prev = if (delta > Byte.MinValue) prev + delta else buffer.getInt() + row.setInt(ordinal, prev) } - - override def hasNext = buffer.hasRemaining } } -private[sql] case object IntDelta extends IntegralDelta[IntegerType.type] { - override val typeId = 4 +private[sql] case object LongDelta extends CompressionScheme { + override def typeId: 
Int = 5 - override def supports(columnType: ColumnType[_, _]) = columnType == INT + override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + new Decoder(buffer, LONG).asInstanceOf[compression.Decoder[T]] + } + + override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = { + (new Encoder).asInstanceOf[compression.Encoder[T]] + } - override protected def addDelta(x: Int, delta: Byte) = x + delta + override def supports(columnType: ColumnType[_, _]) = columnType == LONG + + class Encoder extends compression.Encoder[LongType.type] { + protected var _compressedSize: Int = 0 + protected var _uncompressedSize: Int = 0 + + override def compressedSize = _compressedSize + override def uncompressedSize = _uncompressedSize + + private var prevValue: Long = _ + + override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = { + val value = row.getLong(ordinal) + val delta = value - prevValue + + _compressedSize += 1 - override protected def byteSizedDelta(x: Int, y: Int): (Boolean, Byte) = { - val delta = x - y - if (math.abs(delta) <= Byte.MaxValue) (true, delta.toByte) else (false, 0: Byte) + // If this is the first long integer to be compressed, or the delta is out of byte range, then + // give up compressing this long integer. + if (_uncompressedSize == 0 || delta <= Byte.MinValue || delta > Byte.MaxValue) { + _compressedSize += LONG.defaultSize + } + + _uncompressedSize += LONG.defaultSize + prevValue = value + } + + override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = { + to.putInt(typeId) + + if (from.hasRemaining) { + var prev = from.getLong() + to.put(Byte.MinValue) + to.putLong(prev) + + while (from.hasRemaining) { + val current = from.getLong() + val delta = current - prev + prev = current + + if (Byte.MinValue < delta && delta <= Byte.MaxValue) { + to.put(delta.toByte) + } else { + to.put(Byte.MinValue) + to.putLong(current) + } + } + } + + to.rewind().asInstanceOf[ByteBuffer] + } } -} -private[sql] case object LongDelta extends IntegralDelta[LongType.type] { - override val typeId = 5 + class Decoder(buffer: ByteBuffer, columnType: NativeColumnType[LongType.type]) + extends compression.Decoder[LongType.type] { - override def supports(columnType: ColumnType[_, _]) = columnType == LONG + private var prev: Long = _ - override protected def addDelta(x: Long, delta: Byte) = x + delta + override def hasNext: Boolean = buffer.hasRemaining - override protected def byteSizedDelta(x: Long, y: Long): (Boolean, Byte) = { - val delta = x - y - if (math.abs(delta) <= Byte.MaxValue) (true, delta.toByte) else (false, 0: Byte) + override def next(row: MutableRow, ordinal: Int): Unit = { + val delta = buffer.get() + prev = if (delta > Byte.MinValue) prev + delta else buffer.getLong() + row.setLong(ordinal, prev) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 70062eae3b7ce..0f27fd13e7379 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -68,8 +68,15 @@ private[sql] object JsonRDD extends Logging { val (topLevel, structLike) = values.partition(_.size == 1) val topLevelFields = topLevel.filter { name => resolved.get(prefix ++ name).get match { - case ArrayType(StructType(Nil), _) => false - case ArrayType(_, _) => true + case ArrayType(elementType, _) => { + def hasInnerStruct(t: DataType): Boolean = t match { + case s: 
StructType => false + case ArrayType(t1, _) => hasInnerStruct(t1) + case o => true + } + + hasInnerStruct(elementType) + } case struct: StructType => false case _ => true } @@ -84,7 +91,18 @@ private[sql] object JsonRDD extends Logging { val dataType = resolved.get(prefix :+ name).get dataType match { case array: ArrayType => - Some(StructField(name, ArrayType(structType, array.containsNull), nullable = true)) + // The pattern of this array is ArrayType(...(ArrayType(StructType))). + // Since the inner struct of array is a placeholder (StructType(Nil)), + // we need to replace this placeholder with the actual StructType (structType). + def getActualArrayType( + innerStruct: StructType, + currentArray: ArrayType): ArrayType = currentArray match { + case ArrayType(s: StructType, containsNull) => + ArrayType(innerStruct, containsNull) + case ArrayType(a: ArrayType, containsNull) => + ArrayType(getActualArrayType(innerStruct, a), containsNull) + } + Some(StructField(name, getActualArrayType(structType, array), nullable = true)) case struct: StructType => Some(StructField(name, structType, nullable = true)) // dataType is StringType means that we have resolved type conflicts involving // primitive types and complex types. So, the type of name has been relaxed to @@ -168,8 +186,7 @@ private[sql] object JsonRDD extends Logging { /** * Returns the element type of an JSON array. We go through all elements of this array * to detect any possible type conflict. We use [[compatibleType]] to resolve - * type conflicts. Right now, when the element of an array is another array, we - * treat the element as String. + * type conflicts. */ private def typeOfArray(l: Seq[Any]): ArrayType = { val containsNull = l.exists(v => v == null) @@ -216,18 +233,24 @@ private[sql] object JsonRDD extends Logging { } case (key: String, array: Seq[_]) => { // The value associated with the key is an array. - typeOfArray(array) match { + // Handle inner structs of an array. + def buildKeyPathForInnerStructs(v: Any, t: DataType): Seq[(String, DataType)] = t match { case ArrayType(StructType(Nil), containsNull) => { // The elements of this arrays are structs. - array.asInstanceOf[Seq[Map[String, Any]]].flatMap { + v.asInstanceOf[Seq[Map[String, Any]]].flatMap { element => allKeysWithValueTypes(element) }.map { - case (k, dataType) => (s"$key.$k", dataType) - } :+ (key, ArrayType(StructType(Nil), containsNull)) + case (k, t) => (s"$key.$k", t) + } } - case ArrayType(elementType, containsNull) => - (key, ArrayType(elementType, containsNull)) :: Nil + case ArrayType(t1, containsNull) => + v.asInstanceOf[Seq[Any]].flatMap { + element => buildKeyPathForInnerStructs(element, t1) + } + case other => Nil } + val elementType = typeOfArray(array) + buildKeyPathForInnerStructs(array, elementType) :+ (key, elementType) } case (key: String, value) => (key, typeOfPrimitiveValue(value)) :: Nil } @@ -264,9 +287,13 @@ private[sql] object JsonRDD extends Logging { // the ObjectMapper will take the last value associated with this duplicate key. // For example: for {"key": 1, "key":2}, we will get "key"->2. 
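    // Hedged illustration (the inputs are made-up examples, not from the test suite):
    // the flatMap-based parsing below lets a single input record be either one JSON
    // object or a JSON array of objects:
    //   """{"a": 1}"""              parses to one map:  Map("a" -> 1)
    //   """[{"a": 1}, {"a": 2}]"""  parses to two maps: Map("a" -> 1), Map("a" -> 2)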
val mapper = new ObjectMapper() - iter.map { record => - val parsed = scalafy(mapper.readValue(record, classOf[java.util.Map[String, Any]])) - parsed.asInstanceOf[Map[String, Any]] + iter.flatMap { record => + val parsed = mapper.readValue(record, classOf[Object]) match { + case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil + case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]] + } + + parsed } }) } @@ -339,8 +366,6 @@ private[sql] object JsonRDD extends Logging { null } else { desiredType match { - case ArrayType(elementType, _) => - value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) case StringType => toString(value) case IntegerType => value.asInstanceOf[IntegerType.JvmType] case LongType => toLong(value) @@ -348,6 +373,10 @@ private[sql] object JsonRDD extends Logging { case DecimalType => toDecimal(value) case BooleanType => value.asInstanceOf[BooleanType.JvmType] case NullType => null + + case ArrayType(elementType, _) => + value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) + case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct) } } } @@ -356,22 +385,9 @@ private[sql] object JsonRDD extends Logging { // TODO: Reuse the row instead of creating a new one for every record. val row = new GenericMutableRow(schema.fields.length) schema.fields.zipWithIndex.foreach { - // StructType - case (StructField(name, fields: StructType, _), i) => - row.update(i, json.get(name).flatMap(v => Option(v)).map( - v => asRow(v.asInstanceOf[Map[String, Any]], fields)).orNull) - - // ArrayType(StructType) - case (StructField(name, ArrayType(structType: StructType, _), _), i) => - row.update(i, - json.get(name).flatMap(v => Option(v)).map( - v => v.asInstanceOf[Seq[Any]].map( - e => asRow(e.asInstanceOf[Map[String, Any]], structType))).orNull) - - // Other cases case (StructField(name, dataType, _), i) => row.update(i, json.get(name).flatMap(v => Option(v)).map( - enforceCorrectType(_, dataType)).getOrElse(null)) + enforceCorrectType(_, dataType)).orNull) } row diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala index f2389f8f0591e..265b67737c475 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -18,8 +18,13 @@ package org.apache.spark.sql.test import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{SQLConf, SQLContext} /** A SQLContext that can be used for local testing. */ object TestSQLContext - extends SQLContext(new SparkContext("local", "TestSQLContext", new SparkConf())) + extends SQLContext(new SparkContext("local[2]", "TestSQLContext", new SparkConf())) { + + /** Fewer partitions to speed up testing. 
*/ + override private[spark] def numShufflePartitions: Int = + getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 514ac543df92a..67563b6c55f4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.test._ import org.scalatest.BeforeAndAfterAll @@ -477,18 +478,48 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { (3, null))) } - test("EXCEPT") { + test("UNION") { + checkAnswer( + sql("SELECT * FROM lowerCaseData UNION SELECT * FROM upperCaseData"), + (1, "A") :: (1, "a") :: (2, "B") :: (2, "b") :: (3, "C") :: (3, "c") :: + (4, "D") :: (4, "d") :: (5, "E") :: (6, "F") :: Nil) + checkAnswer( + sql("SELECT * FROM lowerCaseData UNION SELECT * FROM lowerCaseData"), + (1, "a") :: (2, "b") :: (3, "c") :: (4, "d") :: Nil) + checkAnswer( + sql("SELECT * FROM lowerCaseData UNION ALL SELECT * FROM lowerCaseData"), + (1, "a") :: (1, "a") :: (2, "b") :: (2, "b") :: (3, "c") :: (3, "c") :: + (4, "d") :: (4, "d") :: Nil) + } + test("UNION with column mismatches") { + // Column name mismatches are allowed. + checkAnswer( + sql("SELECT n,l FROM lowerCaseData UNION SELECT N as x1, L as x2 FROM upperCaseData"), + (1, "A") :: (1, "a") :: (2, "B") :: (2, "b") :: (3, "C") :: (3, "c") :: + (4, "D") :: (4, "d") :: (5, "E") :: (6, "F") :: Nil) + // Column type mismatches are not allowed, forcing a type coercion. checkAnswer( - sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM upperCaseData "), + sql("SELECT n FROM lowerCaseData UNION SELECT L FROM upperCaseData"), + ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Tuple1(_))) + // Column type mismatches where a coercion is not possible, in this case between integer + // and array types, trigger a TreeNodeException. + intercept[TreeNodeException[_]] { + sql("SELECT data FROM arrayData UNION SELECT 1 FROM arrayData").collect() + } + } + + test("EXCEPT") { + checkAnswer( + sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM upperCaseData"), (1, "a") :: (2, "b") :: (3, "c") :: (4, "d") :: Nil) checkAnswer( - sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM lowerCaseData "), Nil) + sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM lowerCaseData"), Nil) checkAnswer( - sql("SELECT * FROM upperCaseData EXCEPT SELECT * FROM upperCaseData "), Nil) + sql("SELECT * FROM upperCaseData EXCEPT SELECT * FROM upperCaseData"), Nil) } test("INTERSECT") { @@ -634,6 +665,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT key, value FROM testData WHERE key BETWEEN 9 and 7"), Seq() ) + } + test("cast boolean to string") { + // TODO Ensure true/false string letter casing is consistent with Hive in all cases. 
+ checkAnswer( + sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"), + ("true", "false") :: Nil) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index cde91ceb68c98..0cdbb3167ce36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -35,7 +35,7 @@ class ColumnStatsSuite extends FunSuite { def testColumnStats[T <: NativeType, U <: ColumnStats]( columnStatsClass: Class[U], columnType: NativeColumnType[T], - initialStatistics: Row) { + initialStatistics: Row): Unit = { val columnStatsName = columnStatsClass.getSimpleName diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 75f653f3280bd..4fb1ecf1d532b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -23,6 +23,7 @@ import java.sql.Timestamp import org.scalatest.FunSuite import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -46,10 +47,12 @@ class ColumnTypeSuite extends FunSuite with Logging { def checkActualSize[T <: DataType, JvmType]( columnType: ColumnType[T, JvmType], value: JvmType, - expected: Int) { + expected: Int): Unit = { assertResult(expected, s"Wrong actualSize for $columnType") { - columnType.actualSize(value) + val row = new GenericMutableRow(1) + columnType.setField(row, 0, value) + columnType.actualSize(row, 0) } } @@ -147,7 +150,7 @@ class ColumnTypeSuite extends FunSuite with Logging { def testNativeColumnType[T <: NativeType]( columnType: NativeColumnType[T], putter: (ByteBuffer, T#JvmType) => Unit, - getter: (ByteBuffer) => T#JvmType) { + getter: (ByteBuffer) => T#JvmType): Unit = { testColumnType[T, T#JvmType](columnType, putter, getter) } @@ -155,7 +158,7 @@ class ColumnTypeSuite extends FunSuite with Logging { def testColumnType[T <: DataType, JvmType]( columnType: ColumnType[T, JvmType], putter: (ByteBuffer, JvmType) => Unit, - getter: (ByteBuffer) => JvmType) { + getter: (ByteBuffer) => JvmType): Unit = { val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE) val seq = (0 until 4).map(_ => makeRandomValue(columnType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 0e3c67f5eed29..c1278248ef655 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.columnar import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{SQLConf, QueryTest, TestData} +import org.apache.spark.sql.{QueryTest, TestData} class InMemoryColumnarQuerySuite extends QueryTest { import org.apache.spark.sql.TestData._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index 3baa6f8ec0c83..6c9a9ab6c3418 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -45,7 +45,9 @@ class NullableColumnAccessorSuite extends FunSuite { testNullableColumnAccessor(_) } - def testNullableColumnAccessor[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) { + def testNullableColumnAccessor[T <: DataType, JvmType]( + columnType: ColumnType[T, JvmType]): Unit = { + val typeName = columnType.getClass.getSimpleName.stripSuffix("$") val nullRow = makeNullRow(1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index a77262534a352..f54a21eb4fbb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -41,7 +41,9 @@ class NullableColumnBuilderSuite extends FunSuite { testNullableColumnBuilder(_) } - def testNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) { + def testNullableColumnBuilder[T <: DataType, JvmType]( + columnType: ColumnType[T, JvmType]): Unit = { + val typeName = columnType.getClass.getSimpleName.stripSuffix("$") test(s"$typeName column builder: empty column") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 5d2fd4959197c..69e0adbd3ee0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -28,7 +28,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be val originalColumnBatchSize = columnBatchSize val originalInMemoryPartitionPruning = inMemoryPartitionPruning - override protected def beforeAll() { + override protected def beforeAll(): Unit = { // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch setConf(SQLConf.COLUMN_BATCH_SIZE, "10") val rawData = sparkContext.makeRDD(1 to 100, 5).map(IntegerData) @@ -38,7 +38,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") } - override protected def afterAll() { + override protected def afterAll(): Unit = { setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) } @@ -76,7 +76,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be filter: String, expectedQueryResult: Seq[Int], expectedReadPartitions: Int, - expectedReadBatches: Int) { + expectedReadBatches: Int): Unit = { test(filter) { val query = sql(s"SELECT * FROM intData WHERE $filter") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala index e01cc8b4d20f2..d9e488e0ffd16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.columnar.compression import org.scalatest.FunSuite import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar.{NoopColumnStats, BOOLEAN} import org.apache.spark.sql.columnar.ColumnarTestUtils._ @@ -72,10 +73,14 @@ class BooleanBitSetSuite extends FunSuite { buffer.rewind().position(headerSize + 4) val decoder = BooleanBitSet.decoder(buffer, BOOLEAN) + val mutableRow = new GenericMutableRow(1) if (values.nonEmpty) { values.foreach { assert(decoder.hasNext) - assertResult(_, "Wrong decoded value")(decoder.next()) + assertResult(_, "Wrong decoded value") { + decoder.next(mutableRow, 0) + mutableRow.getBoolean(0) + } } } assert(!decoder.hasNext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala index d2969d906c943..1cdb909146d57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala @@ -21,6 +21,7 @@ import java.nio.ByteBuffer import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ @@ -67,7 +68,7 @@ class DictionaryEncodingSuite extends FunSuite { val buffer = builder.build() val headerSize = CompressionScheme.columnHeaderSize(buffer) // 4 extra bytes for dictionary size - val dictionarySize = 4 + values.map(columnType.actualSize).sum + val dictionarySize = 4 + rows.map(columnType.actualSize(_, 0)).sum // 2 bytes for each `Short` val compressedSize = 4 + dictionarySize + 2 * inputSeq.length // 4 extra bytes for compression scheme type ID @@ -97,11 +98,15 @@ class DictionaryEncodingSuite extends FunSuite { buffer.rewind().position(headerSize + 4) val decoder = DictionaryEncoding.decoder(buffer, columnType) + val mutableRow = new GenericMutableRow(1) if (inputSeq.nonEmpty) { inputSeq.foreach { i => assert(decoder.hasNext) - assertResult(values(i), "Wrong decoded value")(decoder.next()) + assertResult(values(i), "Wrong decoded value") { + decoder.next(mutableRow, 0) + columnType.getField(mutableRow, 0) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala index 322f447c24840..73f31c0233343 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala @@ -31,7 +31,7 @@ class IntegralDeltaSuite extends FunSuite { def testIntegralDelta[I <: IntegralType]( columnStats: ColumnStats, columnType: NativeColumnType[I], - scheme: IntegralDelta[I]) { + scheme: CompressionScheme) { def skeleton(input: Seq[I#JvmType]) { // ------------- @@ -96,10 +96,15 @@ class IntegralDeltaSuite extends FunSuite { buffer.rewind().position(headerSize + 4) val decoder = scheme.decoder(buffer, columnType) + val mutableRow = new GenericMutableRow(1) + if (input.nonEmpty) { input.foreach{ assert(decoder.hasNext) - assertResult(_, 
"Wrong decoded value")(decoder.next()) + assertResult(_, "Wrong decoded value") { + decoder.next(mutableRow, 0) + columnType.getField(mutableRow, 0) + } } } assert(!decoder.hasNext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala index 218c09ac26362..4ce2552112c92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.columnar.compression import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ @@ -57,7 +58,7 @@ class RunLengthEncodingSuite extends FunSuite { // Compression scheme ID + compressed contents val compressedSize = 4 + inputRuns.map { case (index, _) => // 4 extra bytes each run for run length - columnType.actualSize(values(index)) + 4 + columnType.actualSize(rows(index), 0) + 4 }.sum // 4 extra bytes for compression scheme type ID @@ -80,11 +81,15 @@ class RunLengthEncodingSuite extends FunSuite { buffer.rewind().position(headerSize + 4) val decoder = RunLengthEncoding.decoder(buffer, columnType) + val mutableRow = new GenericMutableRow(1) if (inputSeq.nonEmpty) { inputSeq.foreach { i => assert(decoder.hasNext) - assertResult(values(i), "Wrong decoded value")(decoder.next()) + assertResult(values(i), "Wrong decoded value") { + decoder.next(mutableRow, 0) + columnType.getField(mutableRow, 0) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 301d482d27d86..685e788207725 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -591,8 +591,52 @@ class JsonSuite extends QueryTest { (true, "str1") :: Nil ) checkAnswer( - sql("select complexArrayOfStruct[0].field1[1].inner2[0], complexArrayOfStruct[1].field2[0][1] from jsonTable"), + sql( + """ + |select complexArrayOfStruct[0].field1[1].inner2[0], complexArrayOfStruct[1].field2[0][1] + |from jsonTable + """.stripMargin), ("str2", 6) :: Nil ) } + + test("SPARK-3390 Complex arrays") { + val jsonSchemaRDD = jsonRDD(complexFieldAndType2) + jsonSchemaRDD.registerTempTable("jsonTable") + + checkAnswer( + sql( + """ + |select arrayOfArray1[0][0][0], arrayOfArray1[1][0][1], arrayOfArray1[1][1][0] + |from jsonTable + """.stripMargin), + (5, 7, 8) :: Nil + ) + checkAnswer( + sql( + """ + |select arrayOfArray2[0][0][0].inner1, arrayOfArray2[1][0], + |arrayOfArray2[1][1][1].inner2[0], arrayOfArray2[2][0][0].inner3[0][0].inner4 + |from jsonTable + """.stripMargin), + ("str1", Nil, "str4", 2) :: Nil + ) + } + + test("SPARK-3308 Read top level JSON arrays") { + val jsonSchemaRDD = jsonRDD(jsonArray) + jsonSchemaRDD.registerTempTable("jsonTable") + + checkAnswer( + sql( + """ + |select a, b, c + |from jsonTable + """.stripMargin), + ("str_a_1", null, null) :: + ("str_a_2", null, null) :: + (null, "str_b_3", null) :: + ("str_a_4", "str_b_4", "str_c_4") ::Nil + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index b3f95f08e8044..fc833b8b54e4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -106,6 +106,41 @@ object TestJsonData { "inner1": "str4" }], "field2": [[5, 6], [7, 8]] - }] + }], + "arrayOfArray1": [ + [ + [5] + ], + [ + [6, 7], + [8] + ]], + "arrayOfArray2": [ + [ + [ + { + "inner1": "str1" + } + ] + ], + [ + [], + [ + {"inner2": ["str3", "str33"]}, + {"inner2": ["str4"], "inner1": "str11"} + ] + ], + [ + [ + {"inner3": [[{"inner4": 2}]]} + ] + ]] }""" :: Nil) + + val jsonArray = + TestSQLContext.sparkContext.parallelize( + """[{"a":"str_a_1"}]""" :: + """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: + """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: + """[]""" :: Nil) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index b0a06cd3ca090..08f7358446b29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -58,8 +58,7 @@ case class AllDataTypes( doubleField: Double, shortField: Short, byteField: Byte, - booleanField: Boolean, - binaryField: Array[Byte]) + booleanField: Boolean) case class AllDataTypesWithNonPrimitiveType( stringField: String, @@ -70,13 +69,14 @@ case class AllDataTypesWithNonPrimitiveType( shortField: Short, byteField: Byte, booleanField: Boolean, - binaryField: Array[Byte], array: Seq[Int], arrayContainsNull: Seq[Option[Int]], map: Map[Int, Long], mapValueContainsNull: Map[Int, Option[Long]], data: Data) +case class BinaryData(binaryData: Array[Byte]) + class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { TestData // Load test data tables. @@ -108,26 +108,26 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA test("Read/Write All Types") { val tempDir = getTempFilePath("parquetTest").getCanonicalPath val range = (0 to 255) - TestSQLContext.sparkContext.parallelize(range) - .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0, - (0 to x).map(_.toByte).toArray)) - .saveAsParquetFile(tempDir) - val result = parquetFile(tempDir).collect() - range.foreach { - i => - assert(result(i).getString(0) == s"$i", s"row $i String field did not match, got ${result(i).getString(0)}") - assert(result(i).getInt(1) === i) - assert(result(i).getLong(2) === i.toLong) - assert(result(i).getFloat(3) === i.toFloat) - assert(result(i).getDouble(4) === i.toDouble) - assert(result(i).getShort(5) === i.toShort) - assert(result(i).getByte(6) === i.toByte) - assert(result(i).getBoolean(7) === (i % 2 == 0)) - assert(result(i)(8) === (0 to i).map(_.toByte).toArray) - } + val data = sparkContext.parallelize(range) + .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0)) + + data.saveAsParquetFile(tempDir) + + checkAnswer( + parquetFile(tempDir), + data.toSchemaRDD.collect().toSeq) } - test("Treat binary as string") { + test("read/write binary data") { + // Since equality for Array[Byte] is broken we test this separately. 
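The SPARK-3308 change exercised by the jsonArray fixture above makes JsonRDD read each input line as a generic JSON value and flatten a top-level array into one record per element, instead of assuming every line is a single object. A rough standalone sketch of that shape, assuming Jackson's ObjectMapper (Jackson 2 import shown; Spark's own build may wire a different Jackson, and this is not the exact JsonRDD code):

    // Sketch only: parse one input line into zero or more record maps.
    import com.fasterxml.jackson.databind.ObjectMapper
    import scala.collection.JavaConverters._

    def parseRecords(line: String): Seq[Map[String, Any]] = {
      val mapper = new ObjectMapper()
      mapper.readValue(line, classOf[Object]) match {
        // A single JSON object becomes one record.
        case map: java.util.Map[_, _] =>
          Seq(map.asInstanceOf[java.util.Map[String, Any]].asScala.toMap)
        // A top-level JSON array (assumed to hold objects) becomes one record per element.
        case list: java.util.List[_] =>
          list.asScala.map(_.asInstanceOf[java.util.Map[String, Any]].asScala.toMap)
        // Anything else (a bare string, number, ...) is dropped in this sketch.
        case _ => Nil
      }
    }

    // parseRecords("""[{"a": 1}, {"b": 2}]""") yields two records; parseRecords("""[]""") yields none.

This mirrors the flatMap over Map-or-List in the JsonRDD hunk earlier in this patch.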
+ val tempDir = getTempFilePath("parquetTest").getCanonicalPath + sparkContext.parallelize(BinaryData("test".getBytes("utf8")) :: Nil).saveAsParquetFile(tempDir) + parquetFile(tempDir) + .map(r => new String(r(0).asInstanceOf[Array[Byte]], "utf8")) + .collect().toSeq == Seq("test") + } + + ignore("Treat binary as string") { val oldIsParquetBinaryAsString = TestSQLContext.isParquetBinaryAsString // Create the test file. @@ -142,37 +142,16 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA StructField("c2", BinaryType, false) :: Nil) val schemaRDD1 = applySchema(rowRDD, schema) schemaRDD1.saveAsParquetFile(path) - val resultWithBinary = parquetFile(path).collect - range.foreach { - i => - assert(resultWithBinary(i).getInt(0) === i) - assert(resultWithBinary(i)(1) === s"val_$i".getBytes) - } - - TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, "true") - // This ParquetRelation always use Parquet types to derive output. - val parquetRelation = new ParquetRelation( - path.toString, - Some(TestSQLContext.sparkContext.hadoopConfiguration), - TestSQLContext) { - override val output = - ParquetTypesConverter.convertToAttributes( - ParquetTypesConverter.readMetaData(new Path(path), conf).getFileMetaData.getSchema, - TestSQLContext.isParquetBinaryAsString) - } - val schemaRDD = new SchemaRDD(TestSQLContext, parquetRelation) - val resultWithString = schemaRDD.collect - range.foreach { - i => - assert(resultWithString(i).getInt(0) === i) - assert(resultWithString(i)(1) === s"val_$i") - } + checkAnswer( + parquetFile(path).select('c1, 'c2.cast(StringType)), + schemaRDD1.select('c1, 'c2.cast(StringType)).collect().toSeq) - schemaRDD.registerTempTable("tmp") + setConf(SQLConf.PARQUET_BINARY_AS_STRING, "true") + parquetFile(path).printSchema() checkAnswer( - sql("SELECT c1, c2 FROM tmp WHERE c2 = 'val_5' OR c2 = 'val_7'"), - (5, "val_5") :: - (7, "val_7") :: Nil) + parquetFile(path), + schemaRDD1.select('c1, 'c2.cast(StringType)).collect().toSeq) + // Set it back. 
TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, oldIsParquetBinaryAsString.toString) @@ -275,34 +254,19 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA test("Read/Write All Types with non-primitive type") { val tempDir = getTempFilePath("parquetTest").getCanonicalPath val range = (0 to 255) - TestSQLContext.sparkContext.parallelize(range) + val data = sparkContext.parallelize(range) .map(x => AllDataTypesWithNonPrimitiveType( s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0, - (0 to x).map(_.toByte).toArray, (0 until x), (0 until x).map(Option(_).filter(_ % 3 == 0)), (0 until x).map(i => i -> i.toLong).toMap, (0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None), Data((0 until x), Nested(x, s"$x")))) - .saveAsParquetFile(tempDir) - val result = parquetFile(tempDir).collect() - range.foreach { - i => - assert(result(i).getString(0) == s"$i", s"row $i String field did not match, got ${result(i).getString(0)}") - assert(result(i).getInt(1) === i) - assert(result(i).getLong(2) === i.toLong) - assert(result(i).getFloat(3) === i.toFloat) - assert(result(i).getDouble(4) === i.toDouble) - assert(result(i).getShort(5) === i.toShort) - assert(result(i).getByte(6) === i.toByte) - assert(result(i).getBoolean(7) === (i % 2 == 0)) - assert(result(i)(8) === (0 to i).map(_.toByte).toArray) - assert(result(i)(9) === (0 until i)) - assert(result(i)(10) === (0 until i).map(i => if (i % 3 == 0) i else null)) - assert(result(i)(11) === (0 until i).map(i => i -> i.toLong).toMap) - assert(result(i)(12) === (0 until i).map(i => i -> i.toLong).toMap + (i -> null)) - assert(result(i)(13) === new GenericRow(Array[Any]((0 until i), new GenericRow(Array[Any](i, s"$i"))))) - } + data.saveAsParquetFile(tempDir) + + checkAnswer( + parquetFile(tempDir), + data.toSchemaRDD.collect().toSeq) } test("self-join parquet files") { @@ -399,23 +363,6 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } } - test("Saving case class RDD table to file and reading it back in") { - val file = getTempFilePath("parquet") - val path = file.toString - val rdd = TestSQLContext.sparkContext.parallelize((1 to 100)) - .map(i => TestRDDEntry(i, s"val_$i")) - rdd.saveAsParquetFile(path) - val readFile = parquetFile(path) - readFile.registerTempTable("tmpx") - val rdd_copy = sql("SELECT * FROM tmpx").collect() - val rdd_orig = rdd.collect() - for(i <- 0 to 99) { - assert(rdd_copy(i).apply(0) === rdd_orig(i).key, s"key error in line $i") - assert(rdd_copy(i).apply(1) === rdd_orig(i).value, s"value error in line $i") - } - Utils.deleteRecursively(file) - } - test("Read a parquet file instead of a directory") { val file = getTempFilePath("parquet") val path = file.toString @@ -448,32 +395,19 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA sql("INSERT OVERWRITE INTO dest SELECT * FROM source").collect() val rdd_copy1 = sql("SELECT * FROM dest").collect() assert(rdd_copy1.size === 100) - assert(rdd_copy1(0).apply(0) === 1) - assert(rdd_copy1(0).apply(1) === "val_1") - // TODO: why does collecting break things? It seems InsertIntoParquet::execute() is - // executed twice otherwise?! 
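The hand-rolled per-field assertions deleted above are replaced throughout this suite by a write, read-back, checkAnswer round trip. A minimal sketch of that pattern, reusing helpers that already appear in this file and assuming the TestSQLContext implicits the suite already imports (the case class and path name here are illustrative only):

    // Illustrative record type; the real tests use AllDataTypes and friends defined above.
    case class Entry(key: Int, value: String)

    val path = getTempFilePath("roundTrip").getCanonicalPath
    val data = sparkContext.parallelize(1 to 100).map(i => Entry(i, s"val_$i"))
    data.saveAsParquetFile(path)
    // Compare what comes back from Parquet with the original rows (checkAnswer comes from QueryTest).
    checkAnswer(parquetFile(path), data.toSchemaRDD.collect().toSeq)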
+ sql("INSERT INTO dest SELECT * FROM source") - val rdd_copy2 = sql("SELECT * FROM dest").collect() + val rdd_copy2 = sql("SELECT * FROM dest").collect().sortBy(_.getInt(0)) assert(rdd_copy2.size === 200) - assert(rdd_copy2(0).apply(0) === 1) - assert(rdd_copy2(0).apply(1) === "val_1") - assert(rdd_copy2(99).apply(0) === 100) - assert(rdd_copy2(99).apply(1) === "val_100") - assert(rdd_copy2(100).apply(0) === 1) - assert(rdd_copy2(100).apply(1) === "val_1") Utils.deleteRecursively(dirname) } test("Insert (appending) to same table via Scala API") { - // TODO: why does collecting break things? It seems InsertIntoParquet::execute() is - // executed twice otherwise?! sql("INSERT INTO testsource SELECT * FROM testsource") val double_rdd = sql("SELECT * FROM testsource").collect() assert(double_rdd != null) assert(double_rdd.size === 30) - for(i <- (0 to 14)) { - assert(double_rdd(i) === double_rdd(i+15), s"error: lines $i and ${i+15} to not match") - } + // let's restore the original test data Utils.deleteRecursively(ParquetTestData.testDir) ParquetTestData.writeFile() diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 45a4c6dc98da0..9d7a02bf7b0b7 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -95,6 +95,15 @@ org.apache.avro avro + ${avro.version} + + + + org.apache.avro + avro-mapred + ${avro.version} + ${avro.mapred.classifier} org.scalatest diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index ced8397972fbd..e0be09e6793ea 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -262,7 +262,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /* An analyzer that uses the Hive metastore. */ @transient override protected[sql] lazy val analyzer = - new Analyzer(catalog, functionRegistry, caseSensitive = false) + new Analyzer(catalog, functionRegistry, caseSensitive = false) { + override val extendedRules = + catalog.CreateTables :: + catalog.PreInsertionCasts :: + ExtractPythonUdfs :: + Nil + } /** * Runs the specified SQL query using Hive. @@ -353,9 +359,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /** Extends QueryExecution with hive specific features. */ protected[sql] abstract class QueryExecution extends super.QueryExecution { - // TODO: Create mixin for the analyzer instead of overriding things here. 
- override lazy val optimizedPlan = - optimizer(ExtractPythonUdfs(catalog.PreInsertionCasts(catalog.CreateTables(analyzed)))) override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy()) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 6571c35499ef4..2c0db9be57e54 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -54,8 +54,8 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with db: Option[String], tableName: String, alias: Option[String]): LogicalPlan = synchronized { - val (dbName, tblName) = processDatabaseAndTableName(db, tableName) - val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase) + val (databaseName, tblName) = processDatabaseAndTableName( + db.getOrElse(hive.sessionState.getCurrentDatabase), tableName) val table = client.getTable(databaseName, tblName) val partitions: Seq[Partition] = if (table.isPartitioned) { @@ -109,18 +109,14 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with */ object CreateTables extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case InsertIntoCreatedTable(db, tableName, child) => + // Wait until children are resolved. + case p: LogicalPlan if !p.childrenResolved => p + + case CreateTableAsSelect(db, tableName, child) => val (dbName, tblName) = processDatabaseAndTableName(db, tableName) val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase) - createTable(databaseName, tblName, child.output) - - InsertIntoTable( - EliminateAnalysisOperators( - lookupRelation(Some(databaseName), tblName, None)), - Map.empty, - child, - overwrite = false) + CreateTableAsSelect(Some(databaseName), tableName, child) } } @@ -130,15 +126,17 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with */ object PreInsertionCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.transform { - // Wait until children are resolved + // Wait until children are resolved. case p: LogicalPlan if !p.childrenResolved => p - case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) => + case p @ InsertIntoTable( + LowerCaseSchema(table: MetastoreRelation), _, child, _) => castChildOutput(p, table, child) case p @ logical.InsertIntoTable( - InMemoryRelation(_, _, _, - HiveTableScan(_, table, _)), _, child, _) => + LowerCaseSchema( + InMemoryRelation(_, _, _, + HiveTableScan(_, table, _))), _, child, _) => castChildOutput(p, table, child) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index c98287c6aa662..21ecf17028dbc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -489,7 +489,7 @@ private[hive] object HiveQl { val (db, tableName) = extractDbNameTableName(tableNameParts) - InsertIntoCreatedTable(db, tableName, nodeToPlan(query)) + CreateTableAsSelect(db, tableName, nodeToPlan(query)) // If its not a "CREATE TABLE AS" like above then just pass it back to hive as a native command. 
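With the CreateTableAsSelect plan node introduced above (and planned in HiveStrategies further down), CREATE TABLE ... AS SELECT can be issued directly through a HiveContext. A small usage sketch against the standard src test table; the new table name is illustrative, and the same shape is covered by the new SQLQuerySuite test later in this patch:

    // hiveContext: an existing HiveContext; src(key INT, value STRING) is the usual Hive test table.
    import hiveContext._

    sql("CREATE TABLE src_backup AS SELECT key, value FROM src")
    // The new table is immediately queryable and holds the SELECT's result.
    assert(sql("SELECT * FROM src_backup").count() == sql("SELECT * FROM src").count())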
case Token("TOK_CREATETABLE", _) => NativePlaceholder diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 72cc01cdf4c84..43dd3d234f73a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -165,6 +165,16 @@ private[hive] trait HiveStrategies { InMemoryRelation(_, _, _, HiveTableScan(_, table, _)), partition, child, overwrite) => InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil + case logical.CreateTableAsSelect(database, tableName, child) => + val query = planLater(child) + CreateTableAsSelect( + database.get, + tableName, + query, + InsertIntoHiveTable(_: MetastoreRelation, + Map(), + query, + true)(hiveContext)) :: Nil case _ => Nil } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 329f80cad471e..84fafcde63d05 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -25,16 +25,14 @@ import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc} import org.apache.hadoop.hive.serde2.Deserializer import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector - +import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} import org.apache.spark.SerializableWritable import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} - -import org.apache.spark.sql.catalyst.expressions.{Attribute, Row, GenericMutableRow, Literal, Cast} -import org.apache.spark.sql.catalyst.types.DataType +import org.apache.spark.sql.catalyst.expressions._ /** * A trait for subclasses that handle table scans. @@ -108,12 +106,12 @@ class HadoopTableReader( val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) val attrsWithIndex = attributes.zipWithIndex - val mutableRow = new GenericMutableRow(attrsWithIndex.length) + val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) + val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter => val hconf = broadcastedHiveConf.value.value val deserializer = deserializerClass.newInstance() deserializer.initialize(hconf, tableDesc.getProperties) - HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow) } @@ -164,33 +162,32 @@ class HadoopTableReader( val tableDesc = relation.tableDesc val broadcastedHiveConf = _broadcastedHiveConf val localDeserializer = partDeserializer - val mutableRow = new GenericMutableRow(attributes.length) - - // split the attributes (output schema) into 2 categories: - // (partition keys, ordinal), (normal attributes, ordinal), the ordinal mean the - // index of the attribute in the output Row. 
- val (partitionKeys, attrs) = attributes.zipWithIndex.partition(attr => { - relation.partitionKeys.indexOf(attr._1) >= 0 - }) - - def fillPartitionKeys(parts: Array[String], row: GenericMutableRow) = { - partitionKeys.foreach { case (attr, ordinal) => - // get partition key ordinal for a given attribute - val partOridinal = relation.partitionKeys.indexOf(attr) - row(ordinal) = Cast(Literal(parts(partOridinal)), attr.dataType).eval(null) + val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) + + // Splits all attributes into two groups, partition key attributes and those that are not. + // Attached indices indicate the position of each attribute in the output schema. + val (partitionKeyAttrs, nonPartitionKeyAttrs) = + attributes.zipWithIndex.partition { case (attr, _) => + relation.partitionKeys.contains(attr) + } + + def fillPartitionKeys(rawPartValues: Array[String], row: MutableRow) = { + partitionKeyAttrs.foreach { case (attr, ordinal) => + val partOrdinal = relation.partitionKeys.indexOf(attr) + row(ordinal) = Cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null) } } - // fill the partition key for the given MutableRow Object + + // Fill all partition keys to the given MutableRow object fillPartitionKeys(partValues, mutableRow) - val hivePartitionRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) - hivePartitionRDD.mapPartitions { iter => + createHadoopRdd(tableDesc, inputPathStr, ifc).mapPartitions { iter => val hconf = broadcastedHiveConf.value.value val deserializer = localDeserializer.newInstance() deserializer.initialize(hconf, partProps) - // fill the non partition key attributes - HadoopTableReader.fillObject(iter, deserializer, attrs, mutableRow) + // fill the non partition key attributes + HadoopTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs, mutableRow) } }.toSeq @@ -257,38 +254,64 @@ private[hive] object HadoopTableReader extends HiveInspectors { } /** - * Transform the raw data(Writable object) into the Row object for an iterable input - * @param iter Iterable input which represented as Writable object - * @param deserializer Deserializer associated with the input writable object - * @param attrs Represents the row attribute names and its zero-based position in the MutableRow - * @param row reusable MutableRow object - * - * @return Iterable Row object that transformed from the given iterable input. + * Transform all given raw `Writable`s into `Row`s. 
+ * + * @param iterator Iterator of all `Writable`s to be transformed + * @param deserializer The `Deserializer` associated with the input `Writable` + * @param nonPartitionKeyAttrs Attributes that should be filled together with their corresponding + * positions in the output schema + * @param mutableRow A reusable `MutableRow` that should be filled + * @return An `Iterator[Row]` transformed from `iterator` */ def fillObject( - iter: Iterator[Writable], + iterator: Iterator[Writable], deserializer: Deserializer, - attrs: Seq[(Attribute, Int)], - row: GenericMutableRow): Iterator[Row] = { + nonPartitionKeyAttrs: Seq[(Attribute, Int)], + mutableRow: MutableRow): Iterator[Row] = { + val soi = deserializer.getObjectInspector().asInstanceOf[StructObjectInspector] - // get the field references according to the attributes(output of the reader) required - val fieldRefs = attrs.map { case (attr, idx) => (soi.getStructFieldRef(attr.name), idx) } + val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { case (attr, ordinal) => + soi.getStructFieldRef(attr.name) -> ordinal + }.unzip + + // Builds specific unwrappers ahead of time according to object inspector types to avoid pattern + // matching and branching costs per row. + val unwrappers: Seq[(Any, MutableRow, Int) => Unit] = fieldRefs.map { + _.getFieldObjectInspector match { + case oi: BooleanObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value)) + case oi: ByteObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setByte(ordinal, oi.get(value)) + case oi: ShortObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setShort(ordinal, oi.get(value)) + case oi: IntObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setInt(ordinal, oi.get(value)) + case oi: LongObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setLong(ordinal, oi.get(value)) + case oi: FloatObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) + case oi: DoubleObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) + case oi => + (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrapData(value, oi) + } + } // Map each tuple to a row object - iter.map { value => + iterator.map { value => val raw = deserializer.deserialize(value) - var idx = 0; - while (idx < fieldRefs.length) { - val fieldRef = fieldRefs(idx)._1 - val fieldIdx = fieldRefs(idx)._2 - val fieldValue = soi.getStructFieldData(raw, fieldRef) - - row(fieldIdx) = unwrapData(fieldValue, fieldRef.getFieldObjectInspector()) - - idx += 1 + var i = 0 + while (i < fieldRefs.length) { + val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) + if (fieldValue == null) { + mutableRow.setNullAt(fieldOrdinals(i)) + } else { + unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) + } + i += 1 } - row: Row + mutableRow: Row } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index a013f3f7a805f..70fb15259e7d7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -35,12 +35,13 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical.{CacheCommand, LogicalPlan, NativeCommand} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.hive._ 
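The fillObject rewrite above builds one unwrapper closure per field ahead of time, keyed on the field's ObjectInspector, so the hot per-row loop never pattern-matches on types. A self-contained illustration of that build-once, apply-per-row technique in plain Scala (deliberately not using the Hive ObjectInspector API):

    // Stand-ins for object-inspector kinds; illustration only.
    sealed trait FieldKind
    case object IntKind extends FieldKind
    case object DoubleKind extends FieldKind
    case object StringKind extends FieldKind

    // Build the per-field setters once, outside the row loop.
    def buildSetters(kinds: Seq[FieldKind]): Array[(Array[Any], Int, Any) => Unit] =
      kinds.map {
        case IntKind    => (row: Array[Any], i: Int, v: Any) => row(i) = v.asInstanceOf[Int]
        case DoubleKind => (row: Array[Any], i: Int, v: Any) => row(i) = v.asInstanceOf[Double]
        case StringKind => (row: Array[Any], i: Int, v: Any) => row(i) = v.toString
      }.toArray

    // Reuse the setters for every record; nulls short-circuit, mirroring setNullAt above.
    def fill(row: Array[Any], setters: Array[(Array[Any], Int, Any) => Unit], raw: Seq[Any]): Unit = {
      var i = 0
      while (i < setters.length) {
        if (raw(i) == null) row(i) = null else setters(i)(row, i, raw(i))
        i += 1
      }
    }

In the real TableReader code the closures call the typed MutableRow setters (setInt, setDouble, ...) chosen from the ObjectInspector, which is what removes the per-row branching cost.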
+import org.apache.spark.sql.SQLConf /* Implicit conversions */ import scala.collection.JavaConversions._ object TestHive - extends TestHiveContext(new SparkContext("local", "TestSQLContext", new SparkConf())) + extends TestHiveContext(new SparkContext("local[2]", "TestSQLContext", new SparkConf())) /** * A locally running test instance of Spark's Hive execution engine. @@ -90,6 +91,10 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { override def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } + /** Fewer partitions to speed up testing. */ + override private[spark] def numShufflePartitions: Int = + getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt + /** * Returns the value of specified environmental variable as a [[java.io.File]] after checking * to ensure it exists @@ -269,7 +274,74 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { |) """.stripMargin.cmd, s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/episodes.avro")}' INTO TABLE episodes".cmd - ) + ), + // THIS TABLE IS NOT THE SAME AS THE HIVE TEST TABLE episodes_partitioned AS DYNAMIC PARITIONING + // IS NOT YET SUPPORTED + TestTable("episodes_part", + s"""CREATE TABLE episodes_part (title STRING, air_date STRING, doctor INT) + |PARTITIONED BY (doctor_pt INT) + |ROW FORMAT SERDE '${classOf[AvroSerDe].getCanonicalName}' + |STORED AS + |INPUTFORMAT '${classOf[AvroContainerInputFormat].getCanonicalName}' + |OUTPUTFORMAT '${classOf[AvroContainerOutputFormat].getCanonicalName}' + |TBLPROPERTIES ( + | 'avro.schema.literal'='{ + | "type": "record", + | "name": "episodes", + | "namespace": "testing.hive.avro.serde", + | "fields": [ + | { + | "name": "title", + | "type": "string", + | "doc": "episode title" + | }, + | { + | "name": "air_date", + | "type": "string", + | "doc": "initial date" + | }, + | { + | "name": "doctor", + | "type": "int", + | "doc": "main actor playing the Doctor in episode" + | } + | ] + | }' + |) + """.stripMargin.cmd, + // WORKAROUND: Required to pass schema to SerDe for partitioned tables. + // TODO: Pass this automatically from the table to partitions. + s""" + |ALTER TABLE episodes_part SET SERDEPROPERTIES ( + | 'avro.schema.literal'='{ + | "type": "record", + | "name": "episodes", + | "namespace": "testing.hive.avro.serde", + | "fields": [ + | { + | "name": "title", + | "type": "string", + | "doc": "episode title" + | }, + | { + | "name": "air_date", + | "type": "string", + | "doc": "initial date" + | }, + | { + | "name": "doctor", + | "type": "int", + | "doc": "main actor playing the Doctor in episode" + | } + | ] + | }' + |) + """.stripMargin.cmd, + s""" + INSERT OVERWRITE TABLE episodes_part PARTITION (doctor_pt=1) + SELECT title, air_date, doctor FROM episodes + """.cmd + ) ) hiveQTestUtilTables.foreach(registerTestTable) @@ -309,15 +381,6 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) } - // It is important that we RESET first as broken hooks that might have been set could break - // other sql exec here. - runSqlHive("RESET") - // For some reason, RESET does not reset the following variables... - runSqlHive("set datanucleus.cache.collections=true") - runSqlHive("set datanucleus.cache.collections.lazy=true") - // Lots of tests fail if we do not change the partition whitelist from the default. 
- runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*") - loadedTables.clear() catalog.client.getAllTables("default").foreach { t => logDebug(s"Deleting table $t") @@ -343,6 +406,14 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { FunctionRegistry.unregisterTemporaryUDF(udfName) } + // It is important that we RESET first as broken hooks that might have been set could break + // other sql exec here. + runSqlHive("RESET") + // For some reason, RESET does not reset the following variables... + runSqlHive("set datanucleus.cache.collections=true") + runSqlHive("set datanucleus.cache.collections.lazy=true") + // Lots of tests fail if we do not change the partition whitelist from the default. + runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*") configure() runSqlHive("USE default") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala new file mode 100644 index 0000000000000..71ea774d77795 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.annotation.Experimental +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LowerCaseSchema +import org.apache.spark.sql.execution.{SparkPlan, Command, LeafNode} +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.MetastoreRelation + +/** + * :: Experimental :: + * Create table and insert the query result into it. + * @param database the database name of the new relation + * @param tableName the table name of the new relation + * @param insertIntoRelation function of creating the `InsertIntoHiveTable` + * by specifying the `MetaStoreRelation`, the data will be inserted into that table. + * TODO Add more table creating properties, e.g. SerDe, StorageHandler, in-memory cache etc. 
+ */ +@Experimental +case class CreateTableAsSelect( + database: String, + tableName: String, + query: SparkPlan, + insertIntoRelation: MetastoreRelation => InsertIntoHiveTable) + extends LeafNode with Command { + + def output = Seq.empty + + // A lazy computing of the metastoreRelation + private[this] lazy val metastoreRelation: MetastoreRelation = { + // Create the table + val sc = sqlContext.asInstanceOf[HiveContext] + sc.catalog.createTable(database, tableName, query.output, false) + // Get the Metastore Relation + sc.catalog.lookupRelation(Some(database), tableName, None) match { + case LowerCaseSchema(r: MetastoreRelation) => r + case o: MetastoreRelation => o + } + } + + override protected[sql] lazy val sideEffectResult: Seq[Row] = { + insertIntoRelation(metastoreRelation).execute + Seq.empty[Row] + } + + override def execute(): RDD[Row] = { + sideEffectResult + sparkContext.emptyRDD[Row] + } + + override def argString: String = { + s"[Database:$database, TableName: $tableName, InsertIntoHiveTable]\n" + query.toString + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 39033bdeac4b0..a284a91a91e31 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -53,9 +53,9 @@ case class InsertIntoHiveTable( (@transient sc: HiveContext) extends UnaryNode { - val outputClass = newSerializer(table.tableDesc).getSerializedClass - @transient private val hiveContext = new Context(sc.hiveconf) - @transient private val db = Hive.get(sc.hiveconf) + @transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass + @transient private lazy val hiveContext = new Context(sc.hiveconf) + @transient private lazy val db = Hive.get(sc.hiveconf) private def newSerializer(tableDesc: TableDesc): Serializer = { val serializer = tableDesc.getDeserializerClass.newInstance().asInstanceOf[Serializer] diff --git a/sql/hive/src/test/resources/golden/Read Partitioned with AvroSerDe-0-e4501461c855cc9071a872a64186c3de b/sql/hive/src/test/resources/golden/Read Partitioned with AvroSerDe-0-e4501461c855cc9071a872a64186c3de new file mode 100644 index 0000000000000..49c8434730ffa --- /dev/null +++ b/sql/hive/src/test/resources/golden/Read Partitioned with AvroSerDe-0-e4501461c855cc9071a872a64186c3de @@ -0,0 +1,8 @@ +The Eleventh Hour 3 April 2010 11 1 +The Doctor's Wife 14 May 2011 11 1 +Horror of Fang Rock 3 September 1977 4 1 +An Unearthly Child 23 November 1963 1 1 +The Mysterious Planet 6 September 1986 6 1 +Rose 26 March 2005 9 1 +The Power of the Daleks 5 November 1966 2 1 +Castrolava 4 January 1982 5 1 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 671c3b162f875..79cc7a3fcc7d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -250,9 +250,9 @@ abstract class HiveComparisonTest } try { - // MINOR HACK: You must run a query before calling reset the first time. 
- TestHive.sql("SHOW TABLES") - if (reset) { TestHive.reset() } + if (reset) { + TestHive.reset() + } val hiveCacheFiles = queryList.zipWithIndex.map { case (queryString, i) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 6bf8d18a5c32c..8c8a8b124ac69 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -295,8 +295,16 @@ class HiveQuerySuite extends HiveComparisonTest { "SELECT (CASE WHEN key > 2 THEN 3 WHEN 2 > key THEN 2 ELSE 0 END) FROM src WHERE key < 15") test("implement identity function using case statement") { - val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src").collect().toSet - val expected = sql("SELECT key FROM src").collect().toSet + val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src") + .map { case Row(i: Int) => i } + .collect() + .toSet + + val expected = sql("SELECT key FROM src") + .map { case Row(i: Int) => i } + .collect() + .toSet + assert(actual === expected) } @@ -559,9 +567,9 @@ class HiveQuerySuite extends HiveComparisonTest { val testVal = "test.val.0" val nonexistentKey = "nonexistent" val KV = "([^=]+)=([^=]*)".r - def collectResults(rdd: SchemaRDD): Set[(String, String)] = - rdd.collect().map { - case Row(key: String, value: String) => key -> value + def collectResults(rdd: SchemaRDD): Set[(String, String)] = + rdd.collect().map { + case Row(key: String, value: String) => key -> value case Row(KV(key, value)) => key -> value }.toSet clear() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala index 8bc72384a64ee..7486bfa82b00b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala @@ -37,4 +37,6 @@ class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll { createQueryTest("Read with RegexSerDe", "SELECT * FROM sales") createQueryTest("Read with AvroSerDe", "SELECT * FROM episodes") + + createQueryTest("Read Partitioned with AvroSerDe", "SELECT * FROM episodes_part") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index b6b8592344ef5..cc125d539c3c2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -17,47 +17,68 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.hive.test.TestHive -import org.apache.hadoop.conf.Configuration -import org.apache.spark.SparkContext._ +import java.io.{DataOutput, DataInput} import java.util -import org.apache.hadoop.fs.{FileSystem, Path} +import java.util.Properties + +import org.apache.spark.util.Utils + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.serde2.{SerDeStats, AbstractSerDe} -import org.apache.hadoop.io.{NullWritable, Writable} +import org.apache.hadoop.io.Writable import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorFactory, ObjectInspector} -import java.util.Properties + import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory -import scala.collection.JavaConversions._ -import java.io.{DataOutput, DataInput} import org.apache.hadoop.hive.ql.udf.generic.GenericUDF import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject +import org.apache.spark.sql.Row +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ + +case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) + /** * A test suite for Hive custom UDFs. */ class HiveUdfSuite extends HiveComparisonTest { - TestHive.sql( - """ + test("spark sql udf test that returns a struct") { + registerFunction("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) + assert(sql( + """ + |SELECT getStruct(1).f1, + | getStruct(1).f2, + | getStruct(1).f3, + | getStruct(1).f4, + | getStruct(1).f5 FROM src LIMIT 1 + """.stripMargin).first() === Row(1, 2, 3, 4, 5)) + } + + test("hive struct udf") { + sql( + """ |CREATE EXTERNAL TABLE hiveUdfTestTable ( | pair STRUCT |) |PARTITIONED BY (partition STRING) |ROW FORMAT SERDE '%s' |STORED AS SEQUENCEFILE - """.stripMargin.format(classOf[PairSerDe].getName) - ) - - TestHive.sql( - "ALTER TABLE hiveUdfTestTable ADD IF NOT EXISTS PARTITION(partition='testUdf') LOCATION '%s'" - .format(this.getClass.getClassLoader.getResource("data/files/testUdf").getFile) - ) - - TestHive.sql("CREATE TEMPORARY FUNCTION testUdf AS '%s'".format(classOf[PairUdf].getName)) - - TestHive.sql("SELECT testUdf(pair) FROM hiveUdfTestTable") - - TestHive.sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") + """. + stripMargin.format(classOf[PairSerDe].getName)) + + val location = Utils.getSparkClassLoader.getResource("data/files/testUdf").getFile + sql(s""" + ALTER TABLE hiveUdfTestTable + ADD IF NOT EXISTS PARTITION(partition='testUdf') + LOCATION '$location'""") + + sql(s"CREATE TEMPORARY FUNCTION testUdf AS '${classOf[PairUdf].getName}'") + sql("SELECT testUdf(pair) FROM hiveUdfTestTable") + sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") + } } class TestPair(x: Int, y: Int) extends Writable with Serializable { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index b99caf77bce28..679efe082f2a0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.QueryTest + +import org.apache.spark.sql.Row import org.apache.spark.sql.hive.test.TestHive._ case class Nested1(f1: Nested2) @@ -54,4 +56,11 @@ class SQLQuerySuite extends QueryTest { sql("SELECT f1.f2.f3 FROM nested"), 1) } + + test("test CTAS") { + checkAnswer(sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src"), Seq.empty[Row]) + checkAnswer( + sql("SELECT key, value FROM test_ctas_123 ORDER BY key"), + sql("SELECT key, value FROM src ORDER BY key").collect().toSeq) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala index 0723be7298e15..e380280f301c1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala @@ -20,14 +20,10 @@ package org.apache.spark.sql.parquet import java.io.File 
-import org.apache.spark.sql.hive.execution.HiveTableScan import org.scalatest.BeforeAndAfterAll -import scala.reflect.ClassTag - -import org.apache.spark.sql.{SQLConf, QueryTest} -import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.sql.hive.test.TestHive._ case class ParquetData(intField: Int, stringField: String) @@ -36,27 +32,19 @@ case class ParquetData(intField: Int, stringField: String) * Tests for our SerDe -> Native parquet scan conversion. */ class ParquetMetastoreSuite extends QueryTest with BeforeAndAfterAll { - override def beforeAll(): Unit = { - setConf("spark.sql.hive.convertMetastoreParquet", "true") - } - - override def afterAll(): Unit = { - setConf("spark.sql.hive.convertMetastoreParquet", "false") - } - - val partitionedTableDir = File.createTempFile("parquettests", "sparksql") - partitionedTableDir.delete() - partitionedTableDir.mkdir() - - (1 to 10).foreach { p => - val partDir = new File(partitionedTableDir, s"p=$p") - sparkContext.makeRDD(1 to 10) - .map(i => ParquetData(i, s"part-$p")) - .saveAsParquetFile(partDir.getCanonicalPath) - } - - sql(s""" + val partitionedTableDir = File.createTempFile("parquettests", "sparksql") + partitionedTableDir.delete() + partitionedTableDir.mkdir() + + (1 to 10).foreach { p => + val partDir = new File(partitionedTableDir, s"p=$p") + sparkContext.makeRDD(1 to 10) + .map(i => ParquetData(i, s"part-$p")) + .saveAsParquetFile(partDir.getCanonicalPath) + } + + sql(s""" create external table partitioned_parquet ( intField INT, @@ -70,7 +58,7 @@ class ParquetMetastoreSuite extends QueryTest with BeforeAndAfterAll { location '${partitionedTableDir.getCanonicalPath}' """) - sql(s""" + sql(s""" create external table normal_parquet ( intField INT, @@ -83,8 +71,15 @@ class ParquetMetastoreSuite extends QueryTest with BeforeAndAfterAll { location '${new File(partitionedTableDir, "p=1").getCanonicalPath}' """) - (1 to 10).foreach { p => - sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)") + (1 to 10).foreach { p => + sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)") + } + + setConf("spark.sql.hive.convertMetastoreParquet", "true") + } + + override def afterAll(): Unit = { + setConf("spark.sql.hive.convertMetastoreParquet", "false") } test("project the partitioning column") { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 457e8ab28ed82..f63560dcb5b89 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -37,7 +37,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.{ActorSupervisorStrategy, ActorReceiver, Receiver} import org.apache.spark.streaming.scheduler._ -import org.apache.spark.streaming.ui.StreamingTab +import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} import org.apache.spark.util.MetadataCleaner /** @@ -158,7 +158,14 @@ class StreamingContext private[streaming] ( private[streaming] val waiter = new ContextWaiter - private[streaming] val uiTab = new StreamingTab(this) + private[streaming] val progressListener = new StreamingJobProgressListener(this) + + private[streaming] val 
uiTab: Option[StreamingTab] = + if (conf.getBoolean("spark.ui.enabled", true)) { + Some(new StreamingTab(this)) + } else { + None + } /** Register streaming source to metrics system */ private val streamingSource = new StreamingSource(this) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala index 75f0e8716dc7e..e35a568ddf115 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala @@ -26,7 +26,7 @@ private[streaming] class StreamingSource(ssc: StreamingContext) extends Source { override val metricRegistry = new MetricRegistry override val sourceName = "%s.StreamingMetrics".format(ssc.sparkContext.appName) - private val streamingListener = ssc.uiTab.listener + private val streamingListener = ssc.progressListener private def registerGauge[T](name: String, f: StreamingJobProgressListener => T, defaultValue: T) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 18605cac7006c..9dc26dc6b32a1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -21,7 +21,7 @@ package org.apache.spark.streaming.api.java import scala.collection.JavaConversions._ import scala.reflect.ClassTag -import java.io.InputStream +import java.io.{Closeable, InputStream} import java.util.{List => JList, Map => JMap} import akka.actor.{Props, SupervisorStrategy} @@ -49,7 +49,7 @@ import org.apache.spark.streaming.receiver.Receiver * respectively. `context.awaitTransformation()` allows the current thread to wait for the * termination of a context by `stop()` or by an exception. */ -class JavaStreamingContext(val ssc: StreamingContext) { +class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { /** * Create a StreamingContext. @@ -540,6 +540,9 @@ class JavaStreamingContext(val ssc: StreamingContext) { def stop(stopSparkContext: Boolean, stopGracefully: Boolean) = { ssc.stop(stopSparkContext, stopGracefully) } + + override def close(): Unit = stop() + } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index 34ac254f337eb..d9d04cd706a04 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -17,18 +17,31 @@ package org.apache.spark.streaming.ui -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkException} import org.apache.spark.streaming.StreamingContext -import org.apache.spark.ui.SparkUITab +import org.apache.spark.ui.{SparkUI, SparkUITab} -/** Spark Web UI tab that shows statistics of a streaming job */ +import StreamingTab._ + +/** + * Spark Web UI tab that shows statistics of a streaming job. + * This assumes the given SparkContext has enabled its SparkUI. 
+ */ private[spark] class StreamingTab(ssc: StreamingContext) - extends SparkUITab(ssc.sc.ui, "streaming") with Logging { + extends SparkUITab(getSparkUI(ssc), "streaming") with Logging { - val parent = ssc.sc.ui - val listener = new StreamingJobProgressListener(ssc) + val parent = getSparkUI(ssc) + val listener = ssc.progressListener ssc.addStreamingListener(listener) attachPage(new StreamingPage(this)) parent.attachTab(this) } + +private object StreamingTab { + def getSparkUI(ssc: StreamingContext): SparkUI = { + ssc.sc.ui.getOrElse { + throw new SparkException("Parent SparkUI to attach this tab to not found!") + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala index f4e11f975de94..99c8d13231aac 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.streaming import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer -import scala.language.postfixOps import org.apache.spark.SparkConf import org.apache.spark.storage.{StorageLevel, StreamBlockId} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 7b33d3b235466..a3cabd6be02fe 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -29,8 +29,6 @@ import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ -import scala.language.postfixOps - class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts with Logging { val master = "local[2]" diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 2861f5335ae36..84fed95a75e67 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.streaming import scala.collection.mutable.ArrayBuffer import scala.concurrent.Future import scala.concurrent.ExecutionContext.Implicits.global -import scala.language.postfixOps import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISuite.scala index 2a0db7564915d..8e30118266855 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISuite.scala @@ -18,19 +18,27 @@ package org.apache.spark.streaming import scala.io.Source -import scala.language.postfixOps import org.scalatest.FunSuite import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ +import org.apache.spark.SparkConf + class UISuite extends FunSuite { // Ignored: See SPARK-1530 ignore("streaming tab in spark UI") { - val ssc = new StreamingContext("local", "test", Seconds(1)) + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set("spark.ui.enabled", "true") + val 
ssc = new StreamingContext(conf, Seconds(1)) + assert(ssc.sc.ui.isDefined, "Spark UI is not started!") + val ui = ssc.sc.ui.get + eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL(ssc.sparkContext.ui.appUIAddress).mkString + val html = Source.fromURL(ui.appUIAddress).mkString assert(!html.contains("random data that should not be present")) // test if streaming tab exist assert(html.toLowerCase.contains("streaming")) @@ -39,8 +47,7 @@ class UISuite extends FunSuite { } eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL( - ssc.sparkContext.ui.appUIAddress.stripSuffix("/") + "/streaming").mkString + val html = Source.fromURL(ui.appUIAddress.stripSuffix("/") + "/streaming").mkString assert(html.toLowerCase.contains("batch")) assert(html.toLowerCase.contains("network")) } diff --git a/tools/pom.xml b/tools/pom.xml index f36674476770c..b90eb0ca250c5 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -63,6 +63,20 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + org.apache.maven.plugins + maven-install-plugin + + true + + org.apache.maven.plugins maven-source-plugin diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index bcf6d43ab34eb..595ded6ae67fa 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -24,6 +24,7 @@ import scala.collection.mutable import scala.collection.JavaConversions._ import scala.reflect.runtime.universe.runtimeMirror import scala.reflect.runtime.{universe => unv} +import scala.util.Try /** * A tool for generating classes to be excluded during binary checking with MIMA. It is expected @@ -121,12 +122,17 @@ object GenerateMIMAIgnore { } def main(args: Array[String]) { + import scala.tools.nsc.io.File val (privateClasses, privateMembers) = privateWithin("org.apache.spark") - scala.tools.nsc.io.File(".generated-mima-class-excludes"). - writeAll(privateClasses.mkString("\n")) + val previousContents = Try(File(".generated-mima-class-excludes").lines()). + getOrElse(Iterator.empty).mkString("\n") + File(".generated-mima-class-excludes") + .writeAll(previousContents + privateClasses.mkString("\n")) println("Created : .generated-mima-class-excludes in current directory.") - scala.tools.nsc.io.File(".generated-mima-member-excludes"). 
- writeAll(privateMembers.mkString("\n")) + val previousMembersContents = Try(File(".generated-mima-member-excludes").lines) + .getOrElse(Iterator.empty).mkString("\n") + File(".generated-mima-member-excludes").writeAll(previousMembersContents + + privateMembers.mkString("\n")) println("Created : .generated-mima-member-excludes in current directory.") } diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 10fc39bba87d1..aff9ab71f0937 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -103,14 +103,6 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa appContext } - def calculateAMMemory(newApp: GetNewApplicationResponse): Int = { - val minResMemory = newApp.getMinimumResourceCapability().getMemory() - val amMemory = ((args.amMemory / minResMemory) * minResMemory) + - ((if ((args.amMemory % minResMemory) == 0) 0 else minResMemory) - - memoryOverhead) - amMemory - } - def setupSecurityToken(amContainer: ContainerLaunchContext) = { // Setup security tokens. val dob = new DataOutputBuffer() diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index 5a1b42c1e17d5..6c93d8582330b 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -48,16 +48,17 @@ private[yarn] class YarnAllocationHandler( private val lastResponseId = new AtomicInteger() private val releaseList: CopyOnWriteArrayList[ContainerId] = new CopyOnWriteArrayList() - override protected def allocateContainers(count: Int): YarnAllocateResponse = { + override protected def allocateContainers(count: Int, pending: Int): YarnAllocateResponse = { var resourceRequests: List[ResourceRequest] = null - logDebug("numExecutors: " + count) + logDebug("asking for additional executors: " + count + " with already pending: " + pending) + val totalNumAsk = count + pending if (count <= 0) { resourceRequests = List() } else if (preferredHostToCount.isEmpty) { logDebug("host preferences is empty") resourceRequests = List(createResourceRequest( - AllocationType.ANY, null, count, YarnSparkHadoopUtil.RM_REQUEST_PRIORITY)) + AllocationType.ANY, null, totalNumAsk, YarnSparkHadoopUtil.RM_REQUEST_PRIORITY)) } else { // request for all hosts in preferred nodes and for numExecutors - // candidates.size, request by default allocation policy. @@ -80,7 +81,7 @@ private[yarn] class YarnAllocationHandler( val anyContainerRequests: ResourceRequest = createResourceRequest( AllocationType.ANY, resource = null, - count, + totalNumAsk, YarnSparkHadoopUtil.RM_REQUEST_PRIORITY) val containerRequests: ArrayBuffer[ResourceRequest] = new ArrayBuffer[ResourceRequest]( @@ -103,7 +104,7 @@ private[yarn] class YarnAllocationHandler( req.addAllReleases(releasedContainerList) if (count > 0) { - logInfo("Allocating %d executor containers with %d of memory each.".format(count, + logInfo("Allocating %d executor containers with %d of memory each.".format(totalNumAsk, executorMemory + memoryOverhead)) } else { logDebug("Empty allocation req .. 
release : " + releasedContainerList) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 5756263e89e21..cde5fff637a39 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -21,12 +21,8 @@ import java.io.IOException import java.net.Socket import java.util.concurrent.atomic.AtomicReference -import scala.collection.JavaConversions._ -import scala.util.Try - import akka.actor._ import akka.remote._ -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.yarn.api._ @@ -107,8 +103,11 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, } } } - // Use priority 30 as it's higher than HDFS. It's the same priority MapReduce is using. - ShutdownHookManager.get().addShutdownHook(cleanupHook, 30) + + // Use higher priority than FileSystem. + assert(ApplicationMaster.SHUTDOWN_HOOK_PRIORITY > FileSystem.SHUTDOWN_HOOK_PRIORITY) + ShutdownHookManager + .get().addShutdownHook(cleanupHook, ApplicationMaster.SHUTDOWN_HOOK_PRIORITY) // Call this to force generation of secret so it gets populated into the // Hadoop UGI. This has to happen before the startUserClass which does a @@ -189,7 +188,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, if (sc == null) { finish(FinalApplicationStatus.FAILED, "Timed out waiting for SparkContext.") } else { - registerAM(sc.ui.appUIAddress, securityMgr) + registerAM(sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) try { userThread.join() } finally { @@ -283,11 +282,9 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, } val sparkContext = sparkContextRef.get() - assert(sparkContext != null || count >= numTries) if (sparkContext == null) { - logError( - "Unable to retrieve sparkContext inspite of waiting for %d, numTries = %d".format( - count * waitTime, numTries)) + logError(("SparkContext did not initialize after waiting for %d ms. Please check earlier" + + " log output for errors. 
Failing the application.").format(numTries * waitTime)) } sparkContext } @@ -409,6 +406,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, object ApplicationMaster extends Logging { + val SHUTDOWN_HOOK_PRIORITY: Int = 30 + private var master: ApplicationMaster = _ def main(args: Array[String]) = { diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 8075b7a7fb837..c96f731923d22 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -300,8 +300,6 @@ trait ClientBase extends Logging { retval.toString } - def calculateAMMemory(newApp: GetNewApplicationResponse): Int - def setupSecurityToken(amContainer: ContainerLaunchContext) def createContainerLaunchContext( @@ -346,7 +344,7 @@ trait ClientBase extends Logging { } amContainer.setEnvironment(env) - val amMemory = calculateAMMemory(newApp) + val amMemory = args.amMemory val javaOpts = ListBuffer[String]() diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 0b8744f4b8bdf..299e38a5eb9c0 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -112,6 +112,9 @@ private[yarn] abstract class YarnAllocator( def allocateResources() = { val missing = maxExecutors - numPendingAllocate.get() - numExecutorsRunning.get() + // this is needed by alpha, do it here since we add numPending right after this + val executorsPending = numPendingAllocate.get() + if (missing > 0) { numPendingAllocate.addAndGet(missing) logInfo("Will Allocate %d executor containers, each with %d memory".format( @@ -121,7 +124,7 @@ private[yarn] abstract class YarnAllocator( logDebug("Empty allocation request ...") } - val allocateResponse = allocateContainers(missing) + val allocateResponse = allocateContainers(missing, executorsPending) val allocatedContainers = allocateResponse.getAllocatedContainers() if (allocatedContainers.size > 0) { @@ -435,9 +438,10 @@ private[yarn] abstract class YarnAllocator( * * @param count Number of containers to allocate. * If zero, should still contact RM (as a heartbeat). + * @param pending Number of containers pending allocate. Only used on alpha. * @return Response to the allocation request. */ - protected def allocateContainers(count: Int): YarnAllocateResponse + protected def allocateContainers(count: Int, pending: Int): YarnAllocateResponse /** Called to release a previously allocated container. 
*/ protected def releaseContainer(container: Container): Unit diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 41c662cd7a6de..6aa6475fe4a18 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -55,7 +55,7 @@ private[spark] class YarnClientSchedulerBackend( val driverHost = conf.get("spark.driver.host") val driverPort = conf.get("spark.driver.port") val hostport = driverHost + ":" + driverPort - conf.set("spark.driver.appUIAddress", sc.ui.appUIHostPort) + sc.ui.foreach { ui => conf.set("spark.driver.appUIAddress", ui.appUIHostPort) } val argsArrayBuf = new ArrayBuffer[String]() argsArrayBuf += ( diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala index 68cc2890f3a22..5480eca7c832c 100644 --- a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala @@ -238,9 +238,6 @@ class ClientBaseSuite extends FunSuite with Matchers { val sparkConf: SparkConf, val yarnConf: YarnConfiguration) extends ClientBase { - override def calculateAMMemory(newApp: GetNewApplicationResponse): Int = - throw new UnsupportedOperationException() - override def setupSecurityToken(amContainer: ContainerLaunchContext): Unit = throw new UnsupportedOperationException() diff --git a/yarn/pom.xml b/yarn/pom.xml index 7fcd7ee0d4547..815a736c2e8fd 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -88,6 +88,20 @@ + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + org.apache.maven.plugins + maven-install-plugin + + true + + org.codehaus.mojo build-helper-maven-plugin diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 313a0d21ce181..82e45e3e7ad54 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -103,15 +103,6 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa clusterMetrics.getNumNodeManagers) } - def calculateAMMemory(newApp: GetNewApplicationResponse) :Int = { - // TODO: Need a replacement for the following code to fix -Xmx? - // val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory() - // var amMemory = ((args.amMemory / minResMemory) * minResMemory) + - // ((if ((args.amMemory % minResMemory) == 0) 0 else minResMemory) - - // memoryOverhead ) - args.amMemory - } - def setupSecurityToken(amContainer: ContainerLaunchContext) = { // Setup security tokens. 
     val dob = new DataOutputBuffer()
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
index 5438f151ac0ad..e44a8db41b97e 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
@@ -47,7 +47,8 @@ private[yarn] class YarnAllocationHandler(
     amClient.releaseAssignedContainer(container.getId())
   }
 
-  override protected def allocateContainers(count: Int): YarnAllocateResponse = {
+  // pending isn't used on stable as the AMRMClient handles incremental asks
+  override protected def allocateContainers(count: Int, pending: Int): YarnAllocateResponse = {
     addResourceRequests(count)
 
     // We have already set the container request. Poll the ResourceManager for a response.
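
To illustrate how an application is expected to interact with the optional web UI introduced above, here is a minimal, hypothetical driver sketch (not part of the patch): the `spark.ui.enabled` flag, `StreamingContext`, and `Seconds` come from the diff, while the master URL, app name, and socket source are placeholders. With the UI disabled, `StreamingContext.uiTab` stays `None`, so no `StreamingTab` is attached, while the `progressListener`-backed `StreamingSource` metrics keep working.

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    // Hypothetical driver: run a streaming job with the web UI switched off.
    // With spark.ui.enabled=false, StreamingContext.uiTab is None and no
    // StreamingTab is attached; progressListener metrics keep working.
    object NoUiStreamingExample {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .setMaster("local[2]")              // placeholder master
          .setAppName("streaming-no-ui")      // placeholder app name
          .set("spark.ui.enabled", "false")
        val ssc = new StreamingContext(conf, Seconds(1))

        val lines = ssc.socketTextStream("localhost", 9999)  // placeholder source
        lines.count().print()

        ssc.start()
        ssc.awaitTermination(10000)           // wait up to 10s for this sketch
        ssc.stop(stopSparkContext = true, stopGracefully = false)
      }
    }

On the Java API side, the new `Closeable` on `JavaStreamingContext` allows the equivalent cleanup via try-with-resources, since `close()` simply delegates to `stop()`.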