diff --git a/LICENSE b/LICENSE index 1c166d1333614..1c1c2c0255fa9 100644 --- a/LICENSE +++ b/LICENSE @@ -396,3 +396,35 @@ INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +======================================================================== +For sbt and sbt-launch-lib.bash in sbt/: +======================================================================== + +// Generated from http://www.opensource.org/licenses/bsd-license.php +Copyright (c) 2011, Paul Phillips. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + * Neither the name of the author nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, +EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/NOTICE b/NOTICE index 7cbb114b2ae2d..dce0c4eaf31ed 100644 --- a/NOTICE +++ b/NOTICE @@ -1,5 +1,5 @@ Apache Spark -Copyright 2013 The Apache Software Foundation. +Copyright 2014 The Apache Software Foundation. This product includes software developed at The Apache Software Foundation (http://www.apache.org/). diff --git a/README.md b/README.md index c840a68f76b17..dc8135b9b8b51 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,12 @@ # Apache Spark -Lightning-Fast Cluster Computing - +Lightning-Fast Cluster Computing - ## Online Documentation You can find the latest Spark documentation, including a programming -guide, on the project webpage at . +guide, on the project webpage at . This README file only contains basic setup instructions. @@ -92,21 +92,10 @@ If your project is built with Maven, add this to your POM file's ` ## Configuration -Please refer to the [Configuration guide](http://spark.incubator.apache.org/docs/latest/configuration.html) +Please refer to the [Configuration guide](http://spark.apache.org/docs/latest/configuration.html) in the online documentation for an overview on how to configure Spark. -## Apache Incubator Notice - -Apache Spark is an effort undergoing incubation at The Apache Software -Foundation (ASF), sponsored by the Apache Incubator. 
Incubation is required of -all newly accepted projects until a further review indicates that the -infrastructure, communications, and decision making process have stabilized in -a manner consistent with other successful ASF projects. While incubation status -is not necessarily a reflection of the completeness or stability of the code, -it does indicate that the project has yet to be fully endorsed by the ASF. - - ## Contributing to Spark Contributions via GitHub pull requests are gladly accepted from their original diff --git a/assembly/pom.xml b/assembly/pom.xml index 82396040251d3..22bbbc57d81d4 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,17 +21,20 @@ org.apache.spark spark-parent - 1.0.0-incubating-SNAPSHOT + 1.0.0-SNAPSHOT ../pom.xml org.apache.spark spark-assembly_2.10 Spark Project Assembly - http://spark.incubator.apache.org/ + http://spark.apache.org/ + pom - ${project.build.directory}/scala-${scala.binary.version}/${project.artifactId}-${project.version}-hadoop${hadoop.version}.jar + scala-${scala.binary.version} + ${project.artifactId}-${project.version}-hadoop${hadoop.version}.jar + ${project.build.directory}/${spark.jar.dir}/${spark.jar.basename} spark /usr/share/spark root @@ -155,6 +158,16 @@ + + spark-ganglia-lgpl + + + org.apache.spark + spark-ganglia-lgpl_${scala.binary.version} + ${project.version} + + + bigtop-dist + yarn-alpha + + + org.apache.avro + avro + + + + diff --git a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala index dd3eed8affe39..70c7474a936dc 100644 --- a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala +++ b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala @@ -27,7 +27,7 @@ object Bagel extends Logging { /** * Runs a Bagel program. - * @param sc [[org.apache.spark.SparkContext]] to use for the program. + * @param sc org.apache.spark.SparkContext to use for the program. * @param vertices vertices of the graph represented as an RDD of (Key, Vertex) pairs. Often the * Key will be the vertex id. * @param messages initial set of messages represented as an RDD of (Key, Message) pairs. Often @@ -38,10 +38,10 @@ object Bagel extends Logging { * @param aggregator [[org.apache.spark.bagel.Aggregator]] performs a reduce across all vertices * after each superstep and provides the result to each vertex in the next * superstep. - * @param partitioner [[org.apache.spark.Partitioner]] partitions values by key + * @param partitioner org.apache.spark.Partitioner partitions values by key * @param numPartitions number of partitions across which to split the graph. * Default is the default parallelism of the SparkContext - * @param storageLevel [[org.apache.spark.storage.StorageLevel]] to use for caching of + * @param storageLevel org.apache.spark.storage.StorageLevel to use for caching of * intermediate RDDs in each superstep. Defaults to caching in memory. 
* @param compute function that takes a Vertex, optional set of (possibly combined) messages to * the Vertex, optional Aggregator and the current superstep, @@ -131,7 +131,7 @@ object Bagel extends Logging { /** * Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]], default - * [[org.apache.spark.HashPartitioner]] and default storage level + * org.apache.spark.HashPartitioner and default storage level */ def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( sc: SparkContext, @@ -146,7 +146,7 @@ object Bagel extends Logging { /** * Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]] and the - * default [[org.apache.spark.HashPartitioner]] + * default org.apache.spark.HashPartitioner */ def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( sc: SparkContext, @@ -166,7 +166,7 @@ object Bagel extends Logging { /** * Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]], - * default [[org.apache.spark.HashPartitioner]], + * default org.apache.spark.HashPartitioner, * [[org.apache.spark.bagel.DefaultCombiner]] and the default storage level */ def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest]( @@ -180,7 +180,7 @@ object Bagel extends Logging { /** * Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]], - * the default [[org.apache.spark.HashPartitioner]] + * the default org.apache.spark.HashPartitioner * and [[org.apache.spark.bagel.DefaultCombiner]] */ def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest]( diff --git a/bin/spark-class b/bin/spark-class index c4225a392d6da..229ae2cebbab3 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -40,34 +40,46 @@ if [ -z "$1" ]; then exit 1 fi -# If this is a standalone cluster daemon, reset SPARK_JAVA_OPTS and SPARK_MEM to reasonable -# values for that; it doesn't need a lot -if [ "$1" = "org.apache.spark.deploy.master.Master" -o "$1" = "org.apache.spark.deploy.worker.Worker" ]; then - SPARK_MEM=${SPARK_DAEMON_MEMORY:-512m} - SPARK_DAEMON_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS -Dspark.akka.logLifecycleEvents=true" - # Do not overwrite SPARK_JAVA_OPTS environment variable in this script - OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS" # Empty by default -else - OUR_JAVA_OPTS="$SPARK_JAVA_OPTS" +if [ -n "$SPARK_MEM" ]; then + echo "Warning: SPARK_MEM is deprecated, please use a more specific config option" + echo "(e.g., spark.executor.memory or SPARK_DRIVER_MEMORY)." fi +# Use SPARK_MEM or 512m as the default memory, to be overridden by specific options +DEFAULT_MEM=${SPARK_MEM:-512m} + +SPARK_DAEMON_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS -Dspark.akka.logLifecycleEvents=true" -# Add java opts for master, worker, executor. The opts maybe null +# Add java opts and memory settings for master, worker, executors, and repl. case "$1" in + # Master and Worker use SPARK_DAEMON_JAVA_OPTS (and specific opts) + SPARK_DAEMON_MEMORY. 'org.apache.spark.deploy.master.Master') - OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_MASTER_OPTS" + OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS $SPARK_MASTER_OPTS" + OUR_JAVA_MEM=${SPARK_DAEMON_MEMORY:-$DEFAULT_MEM} ;; 'org.apache.spark.deploy.worker.Worker') - OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_WORKER_OPTS" + OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS $SPARK_WORKER_OPTS" + OUR_JAVA_MEM=${SPARK_DAEMON_MEMORY:-$DEFAULT_MEM} ;; + + # Executors use SPARK_JAVA_OPTS + SPARK_EXECUTOR_MEMORY. 
'org.apache.spark.executor.CoarseGrainedExecutorBackend') - OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_EXECUTOR_OPTS" + OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_EXECUTOR_OPTS" + OUR_JAVA_MEM=${SPARK_EXECUTOR_MEMORY:-$DEFAULT_MEM} ;; 'org.apache.spark.executor.MesosExecutorBackend') - OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_EXECUTOR_OPTS" + OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_EXECUTOR_OPTS" + OUR_JAVA_MEM=${SPARK_EXECUTOR_MEMORY:-$DEFAULT_MEM} ;; + + # All drivers use SPARK_JAVA_OPTS + SPARK_DRIVER_MEMORY. The repl also uses SPARK_REPL_OPTS. 'org.apache.spark.repl.Main') - OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_REPL_OPTS" + OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_REPL_OPTS" + OUR_JAVA_MEM=${SPARK_DRIVER_MEMORY:-$DEFAULT_MEM} + ;; + *) + OUR_JAVA_OPTS="$SPARK_JAVA_OPTS" + OUR_JAVA_MEM=${SPARK_DRIVER_MEMORY:-$DEFAULT_MEM} ;; esac @@ -83,14 +95,10 @@ else fi fi -# Set SPARK_MEM if it isn't already set since we also use it for this process -SPARK_MEM=${SPARK_MEM:-512m} -export SPARK_MEM - # Set JAVA_OPTS to be able to load native libraries and to set heap size JAVA_OPTS="$OUR_JAVA_OPTS" JAVA_OPTS="$JAVA_OPTS -Djava.library.path=$SPARK_LIBRARY_PATH" -JAVA_OPTS="$JAVA_OPTS -Xms$SPARK_MEM -Xmx$SPARK_MEM" +JAVA_OPTS="$JAVA_OPTS -Xms$OUR_JAVA_MEM -Xmx$OUR_JAVA_MEM" # Load extra JAVA_OPTS from conf/java-opts, if it exists if [ -e "$FWDIR/conf/java-opts" ] ; then JAVA_OPTS="$JAVA_OPTS `cat $FWDIR/conf/java-opts`" diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 80818c78ec24b..f488cfdbeceb6 100755 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -34,22 +34,45 @@ if not "x%1"=="x" goto arg_given goto exit :arg_given -set RUNNING_DAEMON=0 -if "%1"=="spark.deploy.master.Master" set RUNNING_DAEMON=1 -if "%1"=="spark.deploy.worker.Worker" set RUNNING_DAEMON=1 -if "x%SPARK_DAEMON_MEMORY%" == "x" set SPARK_DAEMON_MEMORY=512m +if not "x%SPARK_MEM%"=="x" ( + echo Warning: SPARK_MEM is deprecated, please use a more specific config option + echo e.g., spark.executor.memory or SPARK_DRIVER_MEMORY. +) + +rem Use SPARK_MEM or 512m as the default memory, to be overridden by specific options +set OUR_JAVA_MEM=%SPARK_MEM% +if "x%OUR_JAVA_MEM%"=="x" set OUR_JAVA_MEM=512m + set SPARK_DAEMON_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% -Dspark.akka.logLifecycleEvents=true -if "%RUNNING_DAEMON%"=="1" set SPARK_MEM=%SPARK_DAEMON_MEMORY% -rem Do not overwrite SPARK_JAVA_OPTS environment variable in this script -if "%RUNNING_DAEMON%"=="0" set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% -if "%RUNNING_DAEMON%"=="1" set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% -rem Figure out how much memory to use per executor and set it as an environment -rem variable so that our process sees it and can report it to Mesos -if "x%SPARK_MEM%"=="x" set SPARK_MEM=512m +rem Add java opts and memory settings for master, worker, executors, and repl. +rem Master and Worker use SPARK_DAEMON_JAVA_OPTS (and specific opts) + SPARK_DAEMON_MEMORY. +if "%1"=="org.apache.spark.deploy.master.Master" ( + set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% %SPARK_MASTER_OPTS% + if not "x%SPARK_DAEMON_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DAEMON_MEMORY% +) else if "%1"=="org.apache.spark.deploy.worker.Worker" ( + set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% %SPARK_WORKER_OPTS% + if not "x%SPARK_DAEMON_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DAEMON_MEMORY% + +rem Executors use SPARK_JAVA_OPTS + SPARK_EXECUTOR_MEMORY. 
+) else if "%1"=="org.apache.spark.executor.CoarseGrainedExecutorBackend" ( + set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% %SPARK_EXECUTOR_OPTS% + if not "x%SPARK_EXECUTOR_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_EXECUTOR_MEMORY% +) else if "%1"=="org.apache.spark.executor.MesosExecutorBackend" ( + set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% %SPARK_EXECUTOR_OPTS% + if not "x%SPARK_EXECUTOR_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_EXECUTOR_MEMORY% + +rem All drivers use SPARK_JAVA_OPTS + SPARK_DRIVER_MEMORY. The repl also uses SPARK_REPL_OPTS. +) else if "%1"=="org.apache.spark.repl.Main" ( + set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% %SPARK_REPL_OPTS% + if not "x%SPARK_DRIVER_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DRIVER_MEMORY% +) else ( + set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% + if not "x%SPARK_DRIVER_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DRIVER_MEMORY% +) rem Set JAVA_OPTS to be able to load native libraries and to set heap size -set JAVA_OPTS=%OUR_JAVA_OPTS% -Djava.library.path=%SPARK_LIBRARY_PATH% -Xms%SPARK_MEM% -Xmx%SPARK_MEM% +set JAVA_OPTS=%OUR_JAVA_OPTS% -Djava.library.path=%SPARK_LIBRARY_PATH% -Xms%OUR_JAVA_MEM% -Xmx%OUR_JAVA_MEM% rem Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in ExecutorRunner.scala! rem Test whether the user has built Spark diff --git a/bin/spark-shell b/bin/spark-shell index 2bff06cf70051..7d3fe3aca7f1d 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -45,13 +45,11 @@ if [ "$1" = "--help" ] || [ "$1" = "-h" ]; then exit fi -SPARK_SHELL_OPTS="" - for o in "$@"; do if [ "$1" = "-c" -o "$1" = "--cores" ]; then shift if [[ "$1" =~ $CORE_PATTERN ]]; then - SPARK_SHELL_OPTS="$SPARK_SHELL_OPTS -Dspark.cores.max=$1" + SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Dspark.cores.max=$1" shift else echo "ERROR: wrong format for -c/--cores" @@ -61,7 +59,7 @@ for o in "$@"; do if [ "$1" = "-em" -o "$1" = "--execmem" ]; then shift if [[ $1 =~ $MEM_PATTERN ]]; then - SPARK_SHELL_OPTS="$SPARK_SHELL_OPTS -Dspark.executor.memory=$1" + SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Dspark.executor.memory=$1" shift else echo "ERROR: wrong format for --execmem/-em" @@ -71,7 +69,7 @@ for o in "$@"; do if [ "$1" = "-dm" -o "$1" = "--drivermem" ]; then shift if [[ $1 =~ $MEM_PATTERN ]]; then - export SPARK_MEM=$1 + export SPARK_DRIVER_MEMORY=$1 shift else echo "ERROR: wrong format for --drivermem/-dm" @@ -125,16 +123,18 @@ if [[ ! $? ]]; then fi if $cygwin; then - # Workaround for issue involving JLine and Cygwin - # (see http://sourceforge.net/p/jline/bugs/40/). - # If you're using the Mintty terminal emulator in Cygwin, may need to set the - # "Backspace sends ^H" setting in "Keys" section of the Mintty options - # (see https://github.com/sbt/sbt/issues/562). - stty -icanon min 1 -echo > /dev/null 2>&1 - $FWDIR/bin/spark-class -Djline.terminal=unix $SPARK_SHELL_OPTS org.apache.spark.repl.Main "$@" - stty icanon echo > /dev/null 2>&1 + # Workaround for issue involving JLine and Cygwin + # (see http://sourceforge.net/p/jline/bugs/40/). + # If you're using the Mintty terminal emulator in Cygwin, may need to set the + # "Backspace sends ^H" setting in "Keys" section of the Mintty options + # (see https://github.com/sbt/sbt/issues/562). 
+ stty -icanon min 1 -echo > /dev/null 2>&1 + export SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Djline.terminal=unix" + $FWDIR/bin/spark-class org.apache.spark.repl.Main "$@" + stty icanon echo > /dev/null 2>&1 else - $FWDIR/bin/spark-class $SPARK_SHELL_OPTS org.apache.spark.repl.Main "$@" + export SPARK_REPL_OPTS + $FWDIR/bin/spark-class org.apache.spark.repl.Main "$@" fi # record the exit status lest it be overwritten: diff --git a/core/pom.xml b/core/pom.xml index 5576b0c3b4795..a6f478b09bda0 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -17,242 +17,264 @@ --> - 4.0.0 - - org.apache.spark - spark-parent - 1.0.0-incubating-SNAPSHOT - ../pom.xml - - + 4.0.0 + org.apache.spark - spark-core_2.10 - jar - Spark Project Core - http://spark.incubator.apache.org/ + spark-parent + 1.0.0-SNAPSHOT + ../pom.xml + - - - org.apache.hadoop - hadoop-client - - - net.java.dev.jets3t - jets3t - - - commons-logging - commons-logging - - - - - org.apache.avro - avro - - - org.apache.avro - avro-ipc - - - org.apache.zookeeper - zookeeper - - - org.eclipse.jetty - jetty-server - - - com.google.guava - guava - - - com.google.code.findbugs - jsr305 - - - org.slf4j - slf4j-api - - - org.slf4j - jul-to-slf4j - - - org.slf4j - jcl-over-slf4j - - - log4j - log4j - - - org.slf4j - slf4j-log4j12 - - - com.ning - compress-lzf - - - org.xerial.snappy - snappy-java - - - org.ow2.asm - asm - - - com.twitter - chill_${scala.binary.version} - 0.3.1 - - - com.twitter - chill-java - 0.3.1 - - - ${akka.group} - akka-remote_${scala.binary.version} - - - ${akka.group} - akka-slf4j_${scala.binary.version} - - - ${akka.group} - akka-testkit_${scala.binary.version} - test - - - org.scala-lang - scala-library - - - net.liftweb - lift-json_${scala.binary.version} - - - it.unimi.dsi - fastutil - - - colt - colt - - - org.apache.mesos - mesos - - - io.netty - netty-all - - - com.clearspring.analytics - stream - - - com.codahale.metrics - metrics-core - - - com.codahale.metrics - metrics-jvm - - - com.codahale.metrics - metrics-json - - - com.codahale.metrics - metrics-ganglia - - - com.codahale.metrics - metrics-graphite - - - org.apache.derby - derby - test - - - commons-io - commons-io - test - - - org.scalatest - scalatest_${scala.binary.version} - test - - - org.mockito - mockito-all - test - - - org.scalacheck - scalacheck_${scala.binary.version} - test - - - org.easymock - easymock - test - - - com.novocode - junit-interface - test - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - - org.apache.maven.plugins - maven-antrun-plugin - - - test - - run - - - true - - - - - - - - - - - - - - - - - - - - org.scalatest - scalatest-maven-plugin - - - ${basedir}/.. 
- 1 - ${spark.classpath} - - - - - + org.apache.spark + spark-core_2.10 + jar + Spark Project Core + http://spark.apache.org/ + + + + yarn-alpha + + + org.apache.avro + avro + + + + + + + + org.apache.hadoop + hadoop-client + + + net.java.dev.jets3t + jets3t + + + commons-logging + commons-logging + + + + + org.apache.curator + curator-recipes + + + org.eclipse.jetty + jetty-plus + + + org.eclipse.jetty + jetty-security + + + org.eclipse.jetty + jetty-util + + + org.eclipse.jetty + jetty-server + + + com.google.guava + guava + + + com.google.code.findbugs + jsr305 + + + org.slf4j + slf4j-api + + + org.slf4j + jul-to-slf4j + + + org.slf4j + jcl-over-slf4j + + + log4j + log4j + + + org.slf4j + slf4j-log4j12 + + + com.ning + compress-lzf + + + org.xerial.snappy + snappy-java + + + com.twitter + chill_${scala.binary.version} + 0.3.1 + + + com.twitter + chill-java + 0.3.1 + + + commons-net + commons-net + + + ${akka.group} + akka-remote_${scala.binary.version} + + + ${akka.group} + akka-slf4j_${scala.binary.version} + + + ${akka.group} + akka-testkit_${scala.binary.version} + test + + + org.scala-lang + scala-library + + + org.json4s + json4s-jackson_${scala.binary.version} + 3.2.6 + + + + org.scala-lang + scalap + + + + + it.unimi.dsi + fastutil + + + colt + colt + + + org.apache.mesos + mesos + + + io.netty + netty-all + + + com.clearspring.analytics + stream + + + com.codahale.metrics + metrics-core + + + com.codahale.metrics + metrics-jvm + + + com.codahale.metrics + metrics-json + + + com.codahale.metrics + metrics-graphite + + + org.apache.derby + derby + test + + + commons-io + commons-io + test + + + org.scalatest + scalatest_${scala.binary.version} + test + + + org.mockito + mockito-all + test + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.easymock + easymock + test + + + com.novocode + junit-interface + test + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-antrun-plugin + + + test + + run + + + true + + + + + + + + + + + + + + + + + + + + org.scalatest + scalatest-maven-plugin + + + ${basedir}/.. + 1 + ${spark.classpath} + + + + + diff --git a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.scala b/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java similarity index 69% rename from core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.scala rename to core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java index 7500a8943634b..57fd0a7a80494 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.scala +++ b/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java @@ -15,16 +15,13 @@ * limitations under the License. */ -package org.apache.spark.api.java.function +package org.apache.spark.api.java.function; -import java.lang.{Double => JDouble, Iterable => JIterable} +import java.io.Serializable; /** * A function that returns zero or more records of type Double from each input record. */ -// DoubleFlatMapFunction does not extend FlatMapFunction because flatMap is -// overloaded for both FlatMapFunction and DoubleFlatMapFunction. 
-abstract class DoubleFlatMapFunction[T] extends WrappedFunction1[T, JIterable[JDouble]] - with Serializable { - // Intentionally left blank +public interface DoubleFlatMapFunction extends Serializable { + public Iterable call(T t) throws Exception; } diff --git a/project/project/SparkPluginBuild.scala b/core/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java similarity index 69% rename from project/project/SparkPluginBuild.scala rename to core/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java index a88a5e14539ec..150144e0e418c 100644 --- a/project/project/SparkPluginBuild.scala +++ b/core/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java @@ -15,12 +15,13 @@ * limitations under the License. */ -import sbt._ +package org.apache.spark.api.java.function; -object SparkPluginDef extends Build { - lazy val root = Project("plugins", file(".")) dependsOn(junitXmlListener) - /* This is not published in a Maven repository, so we get it from GitHub directly */ - lazy val junitXmlListener = uri( - "https://github.com/chenkelmann/junit_xml_listener.git#3f8029fbfda54dc7a68b1afd2f885935e1090016" - ) +import java.io.Serializable; + +/** + * A function that returns Doubles, and can be used to construct DoubleRDDs. + */ +public interface DoubleFunction extends Serializable { + public double call(T t) throws Exception; } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java similarity index 79% rename from core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala rename to core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java index bdb01f7670356..fa75842047c6a 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.api.java.function +package org.apache.spark.api.java.function; -import scala.reflect.ClassTag +import java.io.Serializable; /** * A function that returns zero or more output records from each input record. */ -abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] { - def elementType(): ClassTag[R] = ClassTag.Any.asInstanceOf[ClassTag[R]] -} +public interface FlatMapFunction extends Serializable { + public Iterable call(T t) throws Exception; +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java similarity index 78% rename from core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala rename to core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java index aae1349c5e17c..d1fdec072443d 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.api.java.function +package org.apache.spark.api.java.function; -import scala.reflect.ClassTag +import java.io.Serializable; /** * A function that takes two inputs and returns zero or more output records. 
*/ -abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] { - def elementType() : ClassTag[C] = ClassTag.Any.asInstanceOf[ClassTag[C]] -} +public interface FlatMapFunction2 extends Serializable { + public Iterable call(T1 t1, T2 t2) throws Exception; +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function.scala b/core/src/main/java/org/apache/spark/api/java/function/Function.java similarity index 72% rename from core/src/main/scala/org/apache/spark/api/java/function/Function.scala rename to core/src/main/java/org/apache/spark/api/java/function/Function.java index a5e1701f7718f..d00551bb0add6 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/Function.scala +++ b/core/src/main/java/org/apache/spark/api/java/function/Function.java @@ -15,17 +15,15 @@ * limitations under the License. */ -package org.apache.spark.api.java.function +package org.apache.spark.api.java.function; -import scala.reflect.ClassTag -import org.apache.spark.api.java.JavaSparkContext +import java.io.Serializable; /** - * Base class for functions whose return types do not create special RDDs. PairFunction and + * Base interface for functions whose return types do not create special RDDs. PairFunction and * DoubleFunction are handled separately, to allow PairRDDs and DoubleRDDs to be constructed * when mapping RDDs of other types. */ -abstract class Function[T, R] extends WrappedFunction1[T, R] with Serializable { - def returnType(): ClassTag[R] = JavaSparkContext.fakeClassTag +public interface Function extends Serializable { + public R call(T1 v1) throws Exception; } - diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function2.scala b/core/src/main/java/org/apache/spark/api/java/function/Function2.java similarity index 76% rename from core/src/main/scala/org/apache/spark/api/java/function/Function2.scala rename to core/src/main/java/org/apache/spark/api/java/function/Function2.java index fa3616cbcb4d2..793caaa61ac5a 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/Function2.scala +++ b/core/src/main/java/org/apache/spark/api/java/function/Function2.java @@ -15,15 +15,13 @@ * limitations under the License. */ -package org.apache.spark.api.java.function +package org.apache.spark.api.java.function; -import scala.reflect.ClassTag -import org.apache.spark.api.java.JavaSparkContext +import java.io.Serializable; /** * A two-argument function that takes arguments of type T1 and T2 and returns an R. */ -abstract class Function2[T1, T2, R] extends WrappedFunction2[T1, T2, R] with Serializable { - def returnType(): ClassTag[R] = JavaSparkContext.fakeClassTag +public interface Function2 extends Serializable { + public R call(T1 v1, T2 v2) throws Exception; } - diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function3.scala b/core/src/main/java/org/apache/spark/api/java/function/Function3.java similarity index 75% rename from core/src/main/scala/org/apache/spark/api/java/function/Function3.scala rename to core/src/main/java/org/apache/spark/api/java/function/Function3.java index 45152891e9272..b4151c3417df4 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/Function3.scala +++ b/core/src/main/java/org/apache/spark/api/java/function/Function3.java @@ -15,14 +15,13 @@ * limitations under the License. 
*/ -package org.apache.spark.api.java.function +package org.apache.spark.api.java.function; -import org.apache.spark.api.java.JavaSparkContext -import scala.reflect.ClassTag +import java.io.Serializable; /** * A three-argument function that takes arguments of type T1, T2 and T3 and returns an R. */ -abstract class Function3[T1, T2, T3, R] extends WrappedFunction3[T1, T2, T3, R] with Serializable { - def returnType(): ClassTag[R] = JavaSparkContext.fakeClassTag +public interface Function3 extends Serializable { + public R call(T1 v1, T2 v2, T3 v3) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java new file mode 100644 index 0000000000000..691ef2eceb1f6 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +import scala.Tuple2; + +/** + * A function that returns zero or more key-value pair records from each input record. The + * key-value pairs are represented as scala.Tuple2 objects. + */ +public interface PairFlatMapFunction extends Serializable { + public Iterable> call(T t) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java b/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java new file mode 100644 index 0000000000000..abd9bcc07ac61 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +import scala.Tuple2; + +/** + * A function that returns key-value pairs (Tuple2), and can be used to construct PairRDDs. 
+ */ +public interface PairFunction extends Serializable { + public Tuple2 call(T t) throws Exception; +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetchTracker.scala b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java similarity index 77% rename from core/src/main/scala/org/apache/spark/storage/BlockFetchTracker.scala rename to core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java index 2e0b0e6eda765..2a10435b7523a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetchTracker.scala +++ b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.storage +package org.apache.spark.api.java.function; -private[spark] trait BlockFetchTracker { - def totalBlocks : Int - def numLocalBlocks: Int - def numRemoteBlocks: Int - def remoteFetchTime : Long - def fetchWaitTime: Long - def remoteBytesRead : Long +import java.io.Serializable; + +/** + * A function with no return value. + */ +public interface VoidFunction extends Serializable { + public void call(T t) throws Exception; } diff --git a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala index 754b46a4c7df2..a67392441ed29 100644 --- a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala @@ -79,7 +79,6 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin val completionIter = CompletionIterator[T, Iterator[T]](itr, { val shuffleMetrics = new ShuffleReadMetrics shuffleMetrics.shuffleFinishTime = System.currentTimeMillis - shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index 1daabecf23292..872e892c04fe6 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -71,10 +71,30 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { val computedValues = rdd.computeOrReadCheckpoint(split, context) // Persist the result, so long as the task is not running locally if (context.runningLocally) { return computedValues } - val elements = new ArrayBuffer[Any] - elements ++= computedValues - blockManager.put(key, elements, storageLevel, tellMaster = true) - elements.iterator.asInstanceOf[Iterator[T]] + if (storageLevel.useDisk && !storageLevel.useMemory) { + // In the case that this RDD is to be persisted using DISK_ONLY + // the iterator will be passed directly to the blockManager (rather then + // caching it to an ArrayBuffer first), then the resulting block data iterator + // will be passed back to the user. If the iterator generates a lot of data, + // this means that it doesn't all have to be held in memory at one time. + // This could also apply to MEMORY_ONLY_SER storage, but we need to make sure + // blocks aren't dropped by the block store before enabling that. 
+ blockManager.put(key, computedValues, storageLevel, tellMaster = true) + return blockManager.get(key) match { + case Some(values) => + return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]]) + case None => + logInfo("Failure to store %s".format(key)) + throw new Exception("Block manager failed to return persisted valued") + } + } else { + // In this case the RDD is cached to an array buffer. This will save the results + // if we're dealing with a 'one-time' iterator + val elements = new ArrayBuffer[Any] + elements ++= computedValues + blockManager.put(key, elements, storageLevel, tellMaster = true) + return elements.iterator.asInstanceOf[Iterator[T]] + } } finally { loading.synchronized { loading.remove(key) diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index cc30105940d1a..448f87b81ef4a 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -18,6 +18,7 @@ package org.apache.spark import org.apache.spark.rdd.RDD +import org.apache.spark.serializer.Serializer /** * Base class for dependencies. @@ -43,12 +44,13 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { * Represents a dependency on the output of a shuffle stage. * @param rdd the parent RDD * @param partitioner partitioner used to partition the shuffle output - * @param serializerClass class name of the serializer to use + * @param serializer [[Serializer]] to use. If set to null, the default serializer, as specified + * by `spark.serializer` config option, will be used. */ class ShuffleDependency[K, V]( @transient rdd: RDD[_ <: Product2[K, V]], val partitioner: Partitioner, - val serializerClass: String = null) + val serializer: Serializer = null) extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) { val shuffleId: Int = rdd.context.newShuffleId() diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index d3264a4bb3c81..3d7692ea8a49e 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -23,7 +23,7 @@ import com.google.common.io.Files import org.apache.spark.util.Utils -private[spark] class HttpFileServer extends Logging { +private[spark] class HttpFileServer(securityManager: SecurityManager) extends Logging { var baseDir : File = null var fileDir : File = null @@ -38,9 +38,10 @@ private[spark] class HttpFileServer extends Logging { fileDir.mkdir() jarDir.mkdir() logInfo("HTTP File server directory is " + baseDir) - httpServer = new HttpServer(baseDir) + httpServer = new HttpServer(baseDir, securityManager) httpServer.start() serverUri = httpServer.uri + logDebug("HTTP file server started at: " + serverUri) } def stop() { diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index 759e68ee0cc61..cb5df25fa48df 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -19,15 +19,18 @@ package org.apache.spark import java.io.File +import org.eclipse.jetty.util.security.{Constraint, Password} +import org.eclipse.jetty.security.authentication.DigestAuthenticator +import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService, SecurityHandler} + import org.eclipse.jetty.server.Server import 
org.eclipse.jetty.server.bio.SocketConnector -import org.eclipse.jetty.server.handler.DefaultHandler -import org.eclipse.jetty.server.handler.HandlerList -import org.eclipse.jetty.server.handler.ResourceHandler +import org.eclipse.jetty.server.handler.{DefaultHandler, HandlerList, ResourceHandler} import org.eclipse.jetty.util.thread.QueuedThreadPool import org.apache.spark.util.Utils + /** * Exception type thrown by HttpServer when it is in the wrong state for an operation. */ @@ -38,7 +41,8 @@ private[spark] class ServerStateException(message: String) extends Exception(mes * as well as classes created by the interpreter when the user types in code. This is just a wrapper * around a Jetty server. */ -private[spark] class HttpServer(resourceBase: File) extends Logging { +private[spark] class HttpServer(resourceBase: File, securityManager: SecurityManager) + extends Logging { private var server: Server = null private var port: Int = -1 @@ -59,14 +63,60 @@ private[spark] class HttpServer(resourceBase: File) extends Logging { server.setThreadPool(threadPool) val resHandler = new ResourceHandler resHandler.setResourceBase(resourceBase.getAbsolutePath) + val handlerList = new HandlerList handlerList.setHandlers(Array(resHandler, new DefaultHandler)) - server.setHandler(handlerList) + + if (securityManager.isAuthenticationEnabled()) { + logDebug("HttpServer is using security") + val sh = setupSecurityHandler(securityManager) + // make sure we go through security handler to get resources + sh.setHandler(handlerList) + server.setHandler(sh) + } else { + logDebug("HttpServer is not using security") + server.setHandler(handlerList) + } + server.start() port = server.getConnectors()(0).getLocalPort() } } + /** + * Setup Jetty to the HashLoginService using a single user with our + * shared secret. Configure it to use DIGEST-MD5 authentication so that the password + * isn't passed in plaintext. + */ + private def setupSecurityHandler(securityMgr: SecurityManager): ConstraintSecurityHandler = { + val constraint = new Constraint() + // use DIGEST-MD5 as the authentication mechanism + constraint.setName(Constraint.__DIGEST_AUTH) + constraint.setRoles(Array("user")) + constraint.setAuthenticate(true) + constraint.setDataConstraint(Constraint.DC_NONE) + + val cm = new ConstraintMapping() + cm.setConstraint(constraint) + cm.setPathSpec("/*") + val sh = new ConstraintSecurityHandler() + + // the hashLoginService lets us do a single user and + // secret right now. This could be changed to use the + // JAASLoginService for other options. 
+ val hashLogin = new HashLoginService() + + val userCred = new Password(securityMgr.getSecretKey()) + if (userCred == null) { + throw new Exception("Error: secret key is null with authentication on") + } + hashLogin.putUser(securityMgr.getHttpUser(), userCred, Array("user")) + sh.setLoginService(hashLogin) + sh.setAuthenticator(new DigestAuthenticator()); + sh.setConstraintMappings(Array(cm)) + sh + } + def stop() { if (server == null) { throw new ServerStateException("Server is already stopped") diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index b749e5414dab6..7423082e34f47 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -19,6 +19,7 @@ package org.apache.spark import org.apache.log4j.{LogManager, PropertyConfigurator} import org.slf4j.{Logger, LoggerFactory} +import org.slf4j.impl.StaticLoggerBinder /** * Utility trait for classes that want to log data. Creates a SLF4J logger for the class and allows @@ -101,9 +102,11 @@ trait Logging { } private def initializeLogging() { - // If Log4j doesn't seem initialized, load a default properties file + // If Log4j is being used, but is not initialized, load a default properties file + val binder = StaticLoggerBinder.getSingleton + val usingLog4j = binder.getLoggerFactoryClassStr.endsWith("Log4jLoggerFactory") val log4jInitialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements - if (!log4jInitialized) { + if (!log4jInitialized && usingLog4j) { val defaultLogProps = "org/apache/spark/log4j-defaults.properties" val classLoader = this.getClass.getClassLoader Option(classLoader.getResource(defaultLogProps)) match { diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 5968973132942..80cbf951cb70e 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -35,13 +35,28 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage -private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster) +private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf) extends Actor with Logging { + val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) + def receive = { case GetMapOutputStatuses(shuffleId: Int) => val hostPort = sender.path.address.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) - sender ! tracker.getSerializedMapOutputStatuses(shuffleId) + val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId) + val serializedSize = mapOutputStatuses.size + if (serializedSize > maxAkkaFrameSize) { + val msg = s"Map output statuses were $serializedSize bytes which " + + s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)." + + /* For SPARK-1244 we'll opt for just logging an error and then throwing an exception. + * Note that on exception the actor will just restart. A bigger refactoring (SPARK-1239) + * will ultimately remove this entire code path. */ + val exception = new SparkException(msg) + logError(msg, exception) + throw exception + } + sender ! 
mapOutputStatuses case StopMapOutputTracker => logInfo("MapOutputTrackerActor stopped!") diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala new file mode 100644 index 0000000000000..591978c1d3630 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.net.{Authenticator, PasswordAuthentication} +import org.apache.hadoop.io.Text +import org.apache.hadoop.security.Credentials +import org.apache.hadoop.security.UserGroupInformation +import org.apache.spark.deploy.SparkHadoopUtil + +import scala.collection.mutable.ArrayBuffer + +/** + * Spark class responsible for security. + * + * In general this class should be instantiated by the SparkEnv and most components + * should access it from that. There are some cases where the SparkEnv hasn't been + * initialized yet and this class must be instantiated directly. + * + * Spark currently supports authentication via a shared secret. + * Authentication can be configured to be on via the 'spark.authenticate' configuration + * parameter. This parameter controls whether the Spark communication protocols do + * authentication using the shared secret. This authentication is a basic handshake to + * make sure both sides have the same shared secret and are allowed to communicate. + * If the shared secret is not identical they will not be allowed to communicate. + * + * The Spark UI can also be secured by using javax servlet filters. A user may want to + * secure the UI if it has data that other users should not be allowed to see. The javax + * servlet filter specified by the user can authenticate the user and then once the user + * is logged in, Spark can compare that user versus the view acls to make sure they are + * authorized to view the UI. The configs 'spark.ui.acls.enable' and 'spark.ui.view.acls' + * control the behavior of the acls. Note that the person who started the application + * always has view access to the UI. + * + * Spark does not currently support encryption after authentication. + * + * At this point spark has multiple communication protocols that need to be secured and + * different underlying mechanisms are used depending on the protocol: + * + * - Akka -> The only option here is to use the Akka Remote secure-cookie functionality. + * Akka remoting allows you to specify a secure cookie that will be exchanged + * and ensured to be identical in the connection handshake between the client + * and the server. If they are not identical then the client will be refused + * to connect to the server. 
There is no control of the underlying + * authentication mechanism so its not clear if the password is passed in + * plaintext or uses DIGEST-MD5 or some other mechanism. + * Akka also has an option to turn on SSL, this option is not currently supported + * but we could add a configuration option in the future. + * + * - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty + * for the HttpServer. Jetty supports multiple authentication mechanisms - + * Basic, Digest, Form, Spengo, etc. It also supports multiple different login + * services - Hash, JAAS, Spnego, JDBC, etc. Spark currently uses the HashLoginService + * to authenticate using DIGEST-MD5 via a single user and the shared secret. + * Since we are using DIGEST-MD5, the shared secret is not passed on the wire + * in plaintext. + * We currently do not support SSL (https), but Jetty can be configured to use it + * so we could add a configuration option for this in the future. + * + * The Spark HttpServer installs the HashLoginServer and configures it to DIGEST-MD5. + * Any clients must specify the user and password. There is a default + * Authenticator installed in the SecurityManager to how it does the authentication + * and in this case gets the user name and password from the request. + * + * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously + * exchange messages. For this we use the Java SASL + * (Simple Authentication and Security Layer) API and again use DIGEST-MD5 + * as the authentication mechanism. This means the shared secret is not passed + * over the wire in plaintext. + * Note that SASL is pluggable as to what mechanism it uses. We currently use + * DIGEST-MD5 but this could be changed to use Kerberos or other in the future. + * Spark currently supports "auth" for the quality of protection, which means + * the connection is not supporting integrity or privacy protection (encryption) + * after authentication. SASL also supports "auth-int" and "auth-conf" which + * SPARK could be support in the future to allow the user to specify the quality + * of protection they want. If we support those, the messages will also have to + * be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's. + * + * Since the connectionManager does asynchronous messages passing, the SASL + * authentication is a bit more complex. A ConnectionManager can be both a client + * and a Server, so for a particular connection is has to determine what to do. + * A ConnectionId was added to be able to track connections and is used to + * match up incoming messages with connections waiting for authentication. + * If its acting as a client and trying to send a message to another ConnectionManager, + * it blocks the thread calling sendMessage until the SASL negotiation has occurred. + * The ConnectionManager tracks all the sendingConnections using the ConnectionId + * and waits for the response from the server and does the handshake. + * + * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters + * can be used. Yarn requires a specific AmIpFilter be installed for security to work + * properly. For non-Yarn deployments, users can write a filter to go through a + * companies normal login service. If an authentication filter is in place then the + * SparkUI can be configured to check the logged in user against the list of users who + * have view acls to see if that user is authorized. + * The filters can also be used for many different purposes. 
For instance filters + * could be used for logging, encryption, or compression. + * + * The exact mechanisms used to generate/distributed the shared secret is deployment specific. + * + * For Yarn deployments, the secret is automatically generated using the Akka remote + * Crypt.generateSecureCookie() API. The secret is placed in the Hadoop UGI which gets passed + * around via the Hadoop RPC mechanism. Hadoop RPC can be configured to support different levels + * of protection. See the Hadoop documentation for more details. Each Spark application on Yarn + * gets a different shared secret. On Yarn, the Spark UI gets configured to use the Hadoop Yarn + * AmIpFilter which requires the user to go through the ResourceManager Proxy. That Proxy is there + * to reduce the possibility of web based attacks through YARN. Hadoop can be configured to use + * filters to do authentication. That authentication then happens via the ResourceManager Proxy + * and Spark will use that to do authorization against the view acls. + * + * For other Spark deployments, the shared secret must be specified via the + * spark.authenticate.secret config. + * All the nodes (Master and Workers) and the applications need to have the same shared secret. + * This again is not ideal as one user could potentially affect another users application. + * This should be enhanced in the future to provide better protection. + * If the UI needs to be secured the user needs to install a javax servlet filter to do the + * authentication. Spark will then use that user to compare against the view acls to do + * authorization. If not filter is in place the user is generally null and no authorization + * can take place. + */ + +private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { + + // key used to store the spark secret in the Hadoop UGI + private val sparkSecretLookupKey = "sparkCookie" + + private val authOn = sparkConf.getBoolean("spark.authenticate", false) + private val uiAclsOn = sparkConf.getBoolean("spark.ui.acls.enable", false) + + // always add the current user and SPARK_USER to the viewAcls + private val aclUsers = ArrayBuffer[String](System.getProperty("user.name", ""), + Option(System.getenv("SPARK_USER")).getOrElse("")) + aclUsers ++= sparkConf.get("spark.ui.view.acls", "").split(',') + private val viewAcls = aclUsers.map(_.trim()).filter(!_.isEmpty).toSet + + private val secretKey = generateSecretKey() + logInfo("SecurityManager, is authentication enabled: " + authOn + + " are ui acls enabled: " + uiAclsOn + " users with view permissions: " + viewAcls.toString()) + + // Set our own authenticator to properly negotiate user/password for HTTP connections. + // This is needed by the HTTP client fetching from the HttpServer. Put here so its + // only set once. + if (authOn) { + Authenticator.setDefault( + new Authenticator() { + override def getPasswordAuthentication(): PasswordAuthentication = { + var passAuth: PasswordAuthentication = null + val userInfo = getRequestingURL().getUserInfo() + if (userInfo != null) { + val parts = userInfo.split(":", 2) + passAuth = new PasswordAuthentication(parts(0), parts(1).toCharArray()) + } + return passAuth + } + } + ) + } + + /** + * Generates or looks up the secret key. + * + * The way the key is stored depends on the Spark deployment mode. Yarn + * uses the Hadoop UGI. + * + * For non-Yarn deployments, If the config variable is not set + * we throw an exception. 
+ */ + private def generateSecretKey(): String = { + if (!isAuthenticationEnabled) return null + // first check to see if the secret is already set, else generate a new one if on yarn + val sCookie = if (SparkHadoopUtil.get.isYarnMode) { + val secretKey = SparkHadoopUtil.get.getSecretKeyFromUserCredentials(sparkSecretLookupKey) + if (secretKey != null) { + logDebug("in yarn mode, getting secret from credentials") + return new Text(secretKey).toString + } else { + logDebug("getSecretKey: yarn mode, secret key from credentials is null") + } + val cookie = akka.util.Crypt.generateSecureCookie + // if we generated the secret then we must be the first so lets set it so t + // gets used by everyone else + SparkHadoopUtil.get.addSecretKeyToUserCredentials(sparkSecretLookupKey, cookie) + logInfo("adding secret to credentials in yarn mode") + cookie + } else { + // user must have set spark.authenticate.secret config + sparkConf.getOption("spark.authenticate.secret") match { + case Some(value) => value + case None => throw new Exception("Error: a secret key must be specified via the " + + "spark.authenticate.secret config") + } + } + sCookie + } + + /** + * Check to see if Acls for the UI are enabled + * @return true if UI authentication is enabled, otherwise false + */ + def uiAclsEnabled(): Boolean = uiAclsOn + + /** + * Checks the given user against the view acl list to see if they have + * authorization to view the UI. If the UI acls must are disabled + * via spark.ui.acls.enable, all users have view access. + * + * @param user to see if is authorized + * @return true is the user has permission, otherwise false + */ + def checkUIViewPermissions(user: String): Boolean = { + if (uiAclsEnabled() && (user != null) && (!viewAcls.contains(user))) false else true + } + + /** + * Check to see if authentication for the Spark communication protocols is enabled + * @return true if authentication is enabled, otherwise false + */ + def isAuthenticationEnabled(): Boolean = authOn + + /** + * Gets the user used for authenticating HTTP connections. + * For now use a single hardcoded user. + * @return the HTTP user as a String + */ + def getHttpUser(): String = "sparkHttpUser" + + /** + * Gets the user used for authenticating SASL connections. + * For now use a single hardcoded user. + * @return the SASL user as a String + */ + def getSaslUser(): String = "sparkSaslUser" + + /** + * Gets the secret key. 
+ * @return the secret key as a String if authentication is enabled, otherwise returns null + */ + def getSecretKey(): String = secretKey +} diff --git a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala index e8f756c408889..a4f69b6b22b2c 100644 --- a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala @@ -29,7 +29,7 @@ private[spark] abstract class ShuffleFetcher { shuffleId: Int, reduceId: Int, context: TaskContext, - serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T] + serializer: Serializer = SparkEnv.get.serializer): Iterator[T] /** Stop the fetcher */ def stop() {} diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a24f07e9a6e9a..852ed8fe1fb91 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -130,6 +130,8 @@ class SparkContext( val isLocal = (master == "local" || master.startsWith("local[")) + if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") + // Create the Spark execution environment (cache, map output tracker, etc) private[spark] val env = SparkEnv.create( conf, @@ -160,19 +162,20 @@ class SparkContext( jars.foreach(addJar) } + def warnSparkMem(value: String): String = { + logWarning("Using SPARK_MEM to set amount of memory to use per executor process is " + + "deprecated, please use spark.executor.memory instead.") + value + } + private[spark] val executorMemory = conf.getOption("spark.executor.memory") - .orElse(Option(System.getenv("SPARK_MEM"))) + .orElse(Option(System.getenv("SPARK_EXECUTOR_MEMORY"))) + .orElse(Option(System.getenv("SPARK_MEM")).map(warnSparkMem)) .map(Utils.memoryStringToMb) .getOrElse(512) - if (!conf.contains("spark.executor.memory") && sys.env.contains("SPARK_MEM")) { - logWarning("Using SPARK_MEM to set amount of memory to use per executor process is " + - "deprecated, instead use spark.executor.memory") - } - // Environment variables to pass to our executors private[spark] val executorEnvs = HashMap[String, String]() - // Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner for (key <- Seq("SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS"); value <- Option(System.getenv(key))) { executorEnvs(key) = value @@ -183,8 +186,9 @@ class SparkContext( value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} { executorEnvs(envKey) = value } - // Since memory can be set with a system property too, use that - executorEnvs("SPARK_MEM") = executorMemory + "m" + // The Mesos scheduler backend relies on this environment variable to set executor memory. + // TODO: Set this only in the Mesos scheduler. + executorEnvs("SPARK_EXECUTOR_MEMORY") = executorMemory + "m" executorEnvs ++= conf.getExecutorEnv // Set SPARK_USER for user who is running SparkContext. 
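The settings read by the SecurityManager introduced above (spark.authenticate, spark.authenticate.secret, spark.ui.acls.enable, spark.ui.view.acls) are ordinary SparkConf entries, so the class can be exercised directly. The following is a minimal sketch, not part of this patch; because SecurityManager is private[spark] the sketch pretends to live in the org.apache.spark package, and the secret and user names are made up.

package org.apache.spark

object SecurityManagerSketch {
  def main(args: Array[String]) {
    val conf = new SparkConf()
      .set("spark.authenticate", "true")                      // turn on connection authentication
      .set("spark.authenticate.secret", "not-a-real-secret")  // required when not running on YARN
      .set("spark.ui.acls.enable", "true")                    // turn on UI view acls
      .set("spark.ui.view.acls", "alice,bob")                 // extra users allowed to view the UI

    val securityManager = new SecurityManager(conf)
    println(securityManager.isAuthenticationEnabled())          // true
    println(securityManager.checkUIViewPermissions("alice"))    // true: listed in the view acls
    println(securityManager.checkUIViewPermissions("mallory"))  // false: acls on, user not listed
  }
}

Note that the current OS user and SPARK_USER are always added to the view acls, so the owner of an application can always see its own UI.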
@@ -240,6 +244,7 @@ class SparkContext( localProperties.set(props) } + @deprecated("Properties no longer need to be explicitly initialized.", "1.0.0") def initLocalProperties() { localProperties.set(new Properties()) } @@ -308,7 +313,7 @@ class SparkContext( private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this) private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this) - def initDriverMetrics() { + private def initDriverMetrics() { SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource) SparkEnv.get.metricsSystem.registerSource(blockManagerSource) } @@ -350,7 +355,7 @@ class SparkContext( * using the older MapReduce API (`org.apache.hadoop.mapred`). * * @param conf JobConf for setting up the dataset - * @param inputFormatClass Class of the [[InputFormat]] + * @param inputFormatClass Class of the InputFormat * @param keyClass Class of the keys * @param valueClass Class of the values * @param minSplits Minimum number of Hadoop Splits to generate. @@ -633,7 +638,7 @@ class SparkContext( addedFiles(key) = System.currentTimeMillis // Fetch the file locally in case a job is executed using DAGScheduler.runLocally(). - Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf) + Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf, env.securityManager) logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) } @@ -735,8 +740,10 @@ class SparkContext( key = uri.getScheme match { // A JAR file which exists only on the driver node case null | "file" => - if (SparkHadoopUtil.get.isYarnMode() && master == "yarn-standalone") { - // In order for this to work in yarn standalone mode the user must specify the + // yarn-standalone is deprecated, but still supported + if (SparkHadoopUtil.get.isYarnMode() && + (master == "yarn-standalone" || master == "yarn-cluster")) { + // In order for this to work in yarn-cluster mode the user must specify the // --addjars option to the client to upload the file into the distributed cache // of the AM to make it show up in the current working directory. val fileName = new Path(uri.getPath).getName() @@ -825,13 +832,12 @@ class SparkContext( setLocalProperty("externalCallSite", null) } + /** + * Capture the current user callsite and return a formatted version for printing. If the user + * has overridden the call site, this will return the user's version. + */ private[spark] def getCallSite(): String = { - val callSite = getLocalProperty("externalCallSite") - if (callSite == null) { - Utils.formatSparkCallSite - } else { - callSite - } + Option(getLocalProperty("externalCallSite")).getOrElse(Utils.formatCallSiteInfo()) } /** @@ -846,6 +852,9 @@ class SparkContext( partitions: Seq[Int], allowLocal: Boolean, resultHandler: (Int, U) => Unit) { + partitions.foreach{ p => + require(p >= 0 && p < rdd.partitions.size, s"Invalid partition requested: $p") + } val callSite = getCallSite val cleanedFunc = clean(func) logInfo("Starting job: " + callSite) @@ -949,6 +958,9 @@ class SparkContext( resultHandler: (Int, U) => Unit, resultFunc: => R): SimpleFutureAction[R] = { + partitions.foreach{ p => + require(p >= 0 && p < rdd.partitions.size, s"Invalid partition requested: $p") + } val cleanF = clean(processPartition) val callSite = getCallSite val waiter = dagScheduler.submitJob( @@ -1024,7 +1036,7 @@ class SparkContext( * The SparkContext object contains a number of implicit conversions and parameters for use with * various Spark features. 
*/ -object SparkContext { +object SparkContext extends Logging { private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description" @@ -1242,7 +1254,11 @@ object SparkContext { } scheduler - case "yarn-standalone" => + case "yarn-standalone" | "yarn-cluster" => + if (master == "yarn-standalone") { + logWarning( + "\"yarn-standalone\" is deprecated as of Spark 1.0. Use \"yarn-cluster\" instead.") + } val scheduler = try { val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 7ac65828f670f..774cbd6441a48 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -28,7 +28,7 @@ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.storage.{BlockManager, BlockManagerMaster, BlockManagerMasterActor} import org.apache.spark.network.ConnectionManager -import org.apache.spark.serializer.{Serializer, SerializerManager} +import org.apache.spark.serializer.Serializer import org.apache.spark.util.{AkkaUtils, Utils} /** @@ -41,7 +41,6 @@ import org.apache.spark.util.{AkkaUtils, Utils} class SparkEnv private[spark] ( val executorId: String, val actorSystem: ActorSystem, - val serializerManager: SerializerManager, val serializer: Serializer, val closureSerializer: Serializer, val cacheManager: CacheManager, @@ -53,7 +52,8 @@ class SparkEnv private[spark] ( val httpFileServer: HttpFileServer, val sparkFilesDir: String, val metricsSystem: MetricsSystem, - val conf: SparkConf) extends Logging { + val conf: SparkConf, + val securityManager: SecurityManager) extends Logging { // A mapping of thread ID to amount of memory used for shuffle in bytes // All accesses should be manually synchronized @@ -122,8 +122,9 @@ object SparkEnv extends Logging { isDriver: Boolean, isLocal: Boolean): SparkEnv = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port, - conf = conf) + val securityManager = new SecurityManager(conf) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port, conf = conf, + securityManager = securityManager) // Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port), // figure out which port number Akka actually bound to and set spark.driver.port to it. @@ -137,17 +138,22 @@ object SparkEnv extends Logging { // defaultClassName if the property is not set, and return it as a T def instantiateClass[T](propertyName: String, defaultClassName: String): T = { val name = conf.get(propertyName, defaultClassName) - Class.forName(name, true, classLoader).newInstance().asInstanceOf[T] + val cls = Class.forName(name, true, classLoader) + // First try with the constructor that takes SparkConf. If we can't find one, + // use a no-arg constructor instead. 
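+      // For example, a hypothetical user-supplied serializer declared as
+      //   class MyCompactSerializer(conf: SparkConf) extends Serializer { ... }
+      // and registered via spark.serializer=com.example.MyCompactSerializer is built through
+      // the SparkConf constructor branch; a serializer exposing only a no-arg constructor
+      // falls back to the second branch.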
+ try { + cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T] + } catch { + case _: NoSuchMethodException => + cls.getConstructor().newInstance().asInstanceOf[T] + } } - val serializerManager = new SerializerManager - - val serializer = serializerManager.setDefault( - conf.get("spark.serializer", "org.apache.spark.serializer.JavaSerializer"), conf) + val serializer = instantiateClass[Serializer]( + "spark.serializer", "org.apache.spark.serializer.JavaSerializer") - val closureSerializer = serializerManager.get( - conf.get("spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer"), - conf) + val closureSerializer = instantiateClass[Serializer]( + "spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer") def registerOrLookup(name: String, newActor: => Actor): ActorRef = { if (isDriver) { @@ -167,12 +173,12 @@ object SparkEnv extends Logging { val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", new BlockManagerMasterActor(isLocal, conf)), conf) - val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, - serializer, conf) + val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, + serializer, conf, securityManager) val connectionManager = blockManager.connectionManager - val broadcastManager = new BroadcastManager(isDriver, conf) + val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) val cacheManager = new CacheManager(blockManager) @@ -185,19 +191,19 @@ object SparkEnv extends Logging { } mapOutputTracker.trackerActor = registerOrLookup( "MapOutputTracker", - new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster])) + new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) val shuffleFetcher = instantiateClass[ShuffleFetcher]( "spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher") - val httpFileServer = new HttpFileServer() + val httpFileServer = new HttpFileServer(securityManager) httpFileServer.initialize() conf.set("spark.fileserver.uri", httpFileServer.serverUri) val metricsSystem = if (isDriver) { - MetricsSystem.createMetricsSystem("driver", conf) + MetricsSystem.createMetricsSystem("driver", conf, securityManager) } else { - MetricsSystem.createMetricsSystem("executor", conf) + MetricsSystem.createMetricsSystem("executor", conf, securityManager) } metricsSystem.start() @@ -219,7 +225,6 @@ object SparkEnv extends Logging { new SparkEnv( executorId, actorSystem, - serializerManager, serializer, closureSerializer, cacheManager, @@ -231,6 +236,7 @@ object SparkEnv extends Logging { httpFileServer, sparkFilesDir, metricsSystem, - conf) + conf, + securityManager) } } diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala new file mode 100644 index 0000000000000..a2a871cbd3c31 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.io.IOException +import javax.security.auth.callback.Callback +import javax.security.auth.callback.CallbackHandler +import javax.security.auth.callback.NameCallback +import javax.security.auth.callback.PasswordCallback +import javax.security.auth.callback.UnsupportedCallbackException +import javax.security.sasl.RealmCallback +import javax.security.sasl.RealmChoiceCallback +import javax.security.sasl.Sasl +import javax.security.sasl.SaslClient +import javax.security.sasl.SaslException + +import scala.collection.JavaConversions.mapAsJavaMap + +/** + * Implements SASL Client logic for Spark + */ +private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logging { + + /** + * Used to respond to server's counterpart, SaslServer with SASL tokens + * represented as byte arrays. + * + * The authentication mechanism used here is DIGEST-MD5. This could be changed to be + * configurable in the future. + */ + private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST), + null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, + new SparkSaslClientCallbackHandler(securityMgr)) + + /** + * Used to initiate SASL handshake with server. + * @return response to challenge if needed + */ + def firstToken(): Array[Byte] = { + synchronized { + val saslToken: Array[Byte] = + if (saslClient != null && saslClient.hasInitialResponse()) { + logDebug("has initial response") + saslClient.evaluateChallenge(new Array[Byte](0)) + } else { + new Array[Byte](0) + } + saslToken + } + } + + /** + * Determines whether the authentication exchange has completed. + * @return true is complete, otherwise false + */ + def isComplete(): Boolean = { + synchronized { + if (saslClient != null) saslClient.isComplete() else false + } + } + + /** + * Respond to server's SASL token. + * @param saslTokenMessage contains server's SASL token + * @return client's response SASL token + */ + def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = { + synchronized { + if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0) + } + } + + /** + * Disposes of any system resources or security-sensitive information the + * SaslClient might be using. + */ + def dispose() { + synchronized { + if (saslClient != null) { + try { + saslClient.dispose() + } catch { + case e: SaslException => // ignored + } finally { + saslClient = null + } + } + } + } + + /** + * Implementation of javax.security.auth.callback.CallbackHandler + * that works with share secrets. + */ + private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends + CallbackHandler { + + private val userName: String = + SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes()) + private val secretKey = securityMgr.getSecretKey() + private val userPassword: Array[Char] = + SparkSaslServer.encodePassword(if (secretKey != null) secretKey.getBytes() else "".getBytes()) + + /** + * Implementation used to respond to SASL request from the server. 
+ * + * @param callbacks objects that indicate what credential information the + * server's SaslServer requires from the client. + */ + override def handle(callbacks: Array[Callback]) { + logDebug("in the sasl client callback handler") + callbacks foreach { + case nc: NameCallback => { + logDebug("handle: SASL client callback: setting username: " + userName) + nc.setName(userName) + } + case pc: PasswordCallback => { + logDebug("handle: SASL client callback: setting userPassword") + pc.setPassword(userPassword) + } + case rc: RealmCallback => { + logDebug("handle: SASL client callback: setting realm: " + rc.getDefaultText()) + rc.setText(rc.getDefaultText()) + } + case cb: RealmChoiceCallback => {} + case cb: Callback => throw + new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback") + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala new file mode 100644 index 0000000000000..11fcb2ae3a5c5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import javax.security.auth.callback.Callback +import javax.security.auth.callback.CallbackHandler +import javax.security.auth.callback.NameCallback +import javax.security.auth.callback.PasswordCallback +import javax.security.auth.callback.UnsupportedCallbackException +import javax.security.sasl.AuthorizeCallback +import javax.security.sasl.RealmCallback +import javax.security.sasl.Sasl +import javax.security.sasl.SaslException +import javax.security.sasl.SaslServer +import scala.collection.JavaConversions.mapAsJavaMap +import org.apache.commons.net.util.Base64 + +/** + * Encapsulates SASL server logic + */ +private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Logging { + + /** + * Actual SASL work done by this object from javax.security.sasl. + */ + private var saslServer: SaslServer = Sasl.createSaslServer(SparkSaslServer.DIGEST, null, + SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, + new SparkSaslDigestCallbackHandler(securityMgr)) + + /** + * Determines whether the authentication exchange has completed. + * @return true is complete, otherwise false + */ + def isComplete(): Boolean = { + synchronized { + if (saslServer != null) saslServer.isComplete() else false + } + } + + /** + * Used to respond to server SASL tokens. + * @param token Server's SASL token + * @return response to send back to the server. 
+ */ + def response(token: Array[Byte]): Array[Byte] = { + synchronized { + if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0) + } + } + + /** + * Disposes of any system resources or security-sensitive information the + * SaslServer might be using. + */ + def dispose() { + synchronized { + if (saslServer != null) { + try { + saslServer.dispose() + } catch { + case e: SaslException => // ignore + } finally { + saslServer = null + } + } + } + } + + /** + * Implementation of javax.security.auth.callback.CallbackHandler + * for SASL DIGEST-MD5 mechanism + */ + private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager) + extends CallbackHandler { + + private val userName: String = + SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes()) + + override def handle(callbacks: Array[Callback]) { + logDebug("In the sasl server callback handler") + callbacks foreach { + case nc: NameCallback => { + logDebug("handle: SASL server callback: setting username") + nc.setName(userName) + } + case pc: PasswordCallback => { + logDebug("handle: SASL server callback: setting userPassword") + val password: Array[Char] = + SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes()) + pc.setPassword(password) + } + case rc: RealmCallback => { + logDebug("handle: SASL server callback: setting realm: " + rc.getDefaultText()) + rc.setText(rc.getDefaultText()) + } + case ac: AuthorizeCallback => { + val authid = ac.getAuthenticationID() + val authzid = ac.getAuthorizationID() + if (authid.equals(authzid)) { + logDebug("set auth to true") + ac.setAuthorized(true) + } else { + logDebug("set auth to false") + ac.setAuthorized(false) + } + if (ac.isAuthorized()) { + logDebug("sasl server is authorized") + ac.setAuthorizedID(authzid) + } + } + case cb: Callback => throw + new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback") + } + } + } +} + +private[spark] object SparkSaslServer { + + /** + * This is passed as the server name when creating the sasl client/server. + * This could be changed to be configurable in the future. + */ + val SASL_DEFAULT_REALM = "default" + + /** + * The authentication mechanism used here is DIGEST-MD5. This could be changed to be + * configurable in the future. + */ + val DIGEST = "DIGEST-MD5" + + /** + * The quality of protection is just "auth". This means that we are doing + * authentication only, we are not supporting integrity or privacy protection of the + * communication channel after authentication. This could be changed to be configurable + * in the future. + */ + val SASL_PROPS = Map(Sasl.QOP -> "auth", Sasl.SERVER_AUTH ->"true") + + /** + * Encode a byte[] identifier as a Base64-encoded string. + * + * @param identifier identifier to encode + * @return Base64-encoded string + */ + def encodeIdentifier(identifier: Array[Byte]): String = { + new String(Base64.encodeBase64(identifier)) + } + + /** + * Encode a password as a base64-encoded char[] array. + * @param password as a byte array. + * @return password as a char array. 
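Taken together, SparkSaslClient and SparkSaslServer implement a DIGEST-MD5 challenge/response exchange: the client opens with firstToken(), the server answers each token via response(), and the client answers challenges via saslResponse() until both sides report isComplete(). The sketch below is not part of this patch; it drives one exchange in a single JVM with a made-up secret, and it sits in the org.apache.spark package because both classes are private[spark]. In a real deployment the two ends live on different nodes and the tokens travel over the wire.

package org.apache.spark

object SaslHandshakeSketch {
  def main(args: Array[String]) {
    val conf = new SparkConf()
      .set("spark.authenticate", "true")
      .set("spark.authenticate.secret", "not-a-real-secret")
    val securityMgr = new SecurityManager(conf)

    val client = new SparkSaslClient(securityMgr)
    val server = new SparkSaslServer(securityMgr)

    // The client speaks first; tokens then ping-pong until the server is satisfied.
    var clientToken = client.firstToken()
    while (!server.isComplete()) {
      val challenge = server.response(clientToken)
      if (!client.isComplete()) {
        clientToken = client.saslResponse(challenge)
      }
    }
    println("client complete: " + client.isComplete() +
      ", server complete: " + server.isComplete())

    client.dispose()
    server.dispose()
  }
}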
+ */ + def encodePassword(password: Array[Byte]): Array[Char] = { + new String(Base64.encodeBase64(password)).toCharArray() + } +} + diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index cae983ed4c652..be53ca2968cfb 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -46,6 +46,7 @@ class TaskContext( } def executeOnCompleteCallbacks() { - onCompleteCallbacks.foreach{_()} + // Process complete callbacks in the reverse order of registration + onCompleteCallbacks.reverse.foreach{_()} } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 071044463d980..f816bb43a5b44 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -83,7 +83,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja * Return a new RDD containing only the elements that satisfy a predicate. */ def filter(f: JFunction[JDouble, java.lang.Boolean]): JavaDoubleRDD = - fromRDD(srdd.filter(x => f(x).booleanValue())) + fromRDD(srdd.filter(x => f.call(x).booleanValue())) /** * Return a new RDD that is reduced into `numPartitions` partitions. @@ -140,6 +140,14 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja */ def union(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.union(other.srdd)) + /** + * Return the intersection of this RDD and another one. The output will not contain any duplicate + * elements, even if the input RDDs did. + * + * Note that this method performs a shuffle internally. + */ + def intersection(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.intersection(other.srdd)) + // Double RDD functions /** Add up the elements in this RDD. */ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 3f672900cb90f..9596dbaf75488 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -26,13 +26,13 @@ import com.google.common.base.Optional import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{JobConf, OutputFormat} -import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} +import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, Job} import org.apache.spark.{HashPartitioner, Partitioner} import org.apache.spark.Partitioner._ import org.apache.spark.SparkContext.rddToPairRDDFunctions import org.apache.spark.api.java.JavaSparkContext.fakeClassTag -import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} +import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, PairFunction} import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.{OrderedRDDFunctions, RDD} import org.apache.spark.storage.StorageLevel @@ -89,7 +89,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Return a new RDD containing only the elements that satisfy a predicate. 
*/ def filter(f: JFunction[(K, V), java.lang.Boolean]): JavaPairRDD[K, V] = - new JavaPairRDD[K, V](rdd.filter(x => f(x).booleanValue())) + new JavaPairRDD[K, V](rdd.filter(x => f.call(x).booleanValue())) /** * Return a new RDD that is reduced into `numPartitions` partitions. @@ -126,6 +126,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) def union(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.union(other.rdd)) + /** + * Return the intersection of this RDD and another one. The output will not contain any duplicate + * elements, even if the input RDDs did. + * + * Note that this method performs a shuffle internally. + */ + def intersection(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] = + new JavaPairRDD[K, V](rdd.intersection(other.rdd)) + + // first() has to be overridden here so that the generated method has the signature // 'public scala.Tuple2 first()'; if the trait's definition is used, // then the method has the signature 'public java.lang.Object first()', @@ -165,9 +175,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Simplified version of combineByKey that hash-partitions the output RDD. */ def combineByKey[C](createCombiner: JFunction[V, C], - mergeValue: JFunction2[C, V, C], - mergeCombiners: JFunction2[C, C, C], - numPartitions: Int): JavaPairRDD[K, C] = + mergeValue: JFunction2[C, V, C], + mergeCombiners: JFunction2[C, C, C], + numPartitions: Int): JavaPairRDD[K, C] = combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions)) /** @@ -442,7 +452,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) */ def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = { import scala.collection.JavaConverters._ - def fn = (x: V) => f.apply(x).asScala + def fn = (x: V) => f.call(x).asScala implicit val ctag: ClassTag[U] = fakeClassTag fromRDD(rdd.flatMapValues(fn)) } @@ -511,49 +521,57 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** Output the RDD to any Hadoop-supported file system. */ def saveAsHadoopFile[F <: OutputFormat[_, _]]( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[F], - conf: JobConf) { + path: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[F], + conf: JobConf) { rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, conf) } /** Output the RDD to any Hadoop-supported file system. */ def saveAsHadoopFile[F <: OutputFormat[_, _]]( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[F]) { + path: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[F]) { rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass) } /** Output the RDD to any Hadoop-supported file system, compressing with the supplied codec. */ def saveAsHadoopFile[F <: OutputFormat[_, _]]( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[F], - codec: Class[_ <: CompressionCodec]) { + path: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[F], + codec: Class[_ <: CompressionCodec]) { rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, codec) } /** Output the RDD to any Hadoop-supported file system. 
*/ def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]]( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[F], - conf: Configuration) { + path: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[F], + conf: Configuration) { rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, conf) } + /** + * Output the RDD to any Hadoop-supported storage system, using + * a Configuration object for that storage system. + */ + def saveAsNewAPIHadoopDataset(conf: Configuration) { + rdd.saveAsNewAPIHadoopDataset(conf) + } + /** Output the RDD to any Hadoop-supported file system. */ def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]]( - path: String, - keyClass: Class[_], - valueClass: Class[_], - outputFormatClass: Class[F]) { + path: String, + keyClass: Class[_], + valueClass: Class[_], + outputFormatClass: Class[F]) { rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass) } @@ -700,6 +718,15 @@ object JavaPairRDD { implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd + private[spark] + implicit def toScalaFunction2[T1, T2, R](fun: JFunction2[T1, T2, R]): Function2[T1, T2, R] = { + (x: T1, x1: T2) => fun.call(x, x1) + } + + private[spark] implicit def toScalaFunction[T, R](fun: JFunction[T, R]): T => R = x => fun.call(x) + + private[spark] + implicit def pairFunToScalaFun[A, B, C](x: PairFunction[A, B, C]): A => (B, C) = y => x.call(y) /** Convert a JavaRDD of key-value pairs to JavaPairRDD. */ def fromJavaRDD[K, V](rdd: JavaRDD[(K, V)]): JavaPairRDD[K, V] = { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 0055c98844ded..01d9357a2556d 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -70,7 +70,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) * Return a new RDD containing only the elements that satisfy a predicate. */ def filter(f: JFunction[T, java.lang.Boolean]): JavaRDD[T] = - wrapRDD(rdd.filter((x => f(x).booleanValue()))) + wrapRDD(rdd.filter((x => f.call(x).booleanValue()))) /** * Return a new RDD that is reduced into `numPartitions` partitions. @@ -106,6 +106,15 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) */ def union(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.union(other.rdd)) + + /** + * Return the intersection of this RDD and another one. The output will not contain any duplicate + * elements, even if the input RDDs did. + * + * Note that this method performs a shuffle internally. + */ + def intersection(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.intersection(other.rdd)) + /** * Return an RDD with the elements from `this` that are not in `other`. * diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 24a9925dbd22c..05b89b985736d 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -19,7 +19,6 @@ package org.apache.spark.api.java import java.util.{Comparator, List => JList} -import scala.Tuple2 import scala.collection.JavaConversions._ import scala.reflect.ClassTag @@ -67,14 +66,14 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to all elements of this RDD. 
*/ def map[R](f: JFunction[T, R]): JavaRDD[R] = - new JavaRDD(rdd.map(f)(f.returnType()))(f.returnType()) + new JavaRDD(rdd.map(f)(fakeClassTag))(fakeClassTag) /** * Return a new RDD by applying a function to each partition of this RDD, while tracking the index * of the original partition. */ def mapPartitionsWithIndex[R: ClassTag]( - f: JFunction2[Int, java.util.Iterator[T], java.util.Iterator[R]], + f: JFunction2[java.lang.Integer, java.util.Iterator[T], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))), preservesPartitioning)) @@ -82,15 +81,16 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Return a new RDD by applying a function to all elements of this RDD. */ - def map[R](f: DoubleFunction[T]): JavaDoubleRDD = - new JavaDoubleRDD(rdd.map(x => f(x).doubleValue())) + def mapToDouble[R](f: DoubleFunction[T]): JavaDoubleRDD = { + new JavaDoubleRDD(rdd.map(x => f.call(x).doubleValue())) + } /** * Return a new RDD by applying a function to all elements of this RDD. */ - def map[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { - val ctag = implicitly[ClassTag[Tuple2[K2, V2]]] - new JavaPairRDD(rdd.map(f)(ctag))(f.keyType(), f.valueType()) + def mapToPair[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { + def cm = implicitly[ClassTag[(K2, V2)]] + new JavaPairRDD(rdd.map[(K2, V2)](f)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } /** @@ -99,17 +99,17 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def flatMap[U](f: FlatMapFunction[T, U]): JavaRDD[U] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.apply(x).asScala - JavaRDD.fromRDD(rdd.flatMap(fn)(f.elementType()))(f.elementType()) + def fn = (x: T) => f.call(x).asScala + JavaRDD.fromRDD(rdd.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U]) } /** * Return a new RDD by first applying a function to all elements of this * RDD, and then flattening the results. */ - def flatMap(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = { + def flatMapToDouble(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.apply(x).asScala + def fn = (x: T) => f.call(x).asScala new JavaDoubleRDD(rdd.flatMap(fn).map((x: java.lang.Double) => x.doubleValue())) } @@ -117,19 +117,19 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by first applying a function to all elements of this * RDD, and then flattening the results. */ - def flatMap[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { + def flatMapToPair[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.apply(x).asScala - val ctag = implicitly[ClassTag[Tuple2[K2, V2]]] - JavaPairRDD.fromRDD(rdd.flatMap(fn)(ctag))(f.keyType(), f.valueType()) + def fn = (x: T) => f.call(x).asScala + def cm = implicitly[ClassTag[(K2, V2)]] + JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } /** * Return a new RDD by applying a function to each partition of this RDD. 
*/ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) - JavaRDD.fromRDD(rdd.mapPartitions(fn)(f.elementType()))(f.elementType()) + def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + JavaRDD.fromRDD(rdd.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } /** @@ -137,52 +137,53 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U], preservesPartitioning: Boolean): JavaRDD[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) - JavaRDD.fromRDD(rdd.mapPartitions(fn, preservesPartitioning)(f.elementType()))(f.elementType()) + def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + JavaRDD.fromRDD( + rdd.mapPartitions(fn, preservesPartitioning)(fakeClassTag[U]))(fakeClassTag[U]) } /** - * Return a new RDD by applying a function to each partition of this RDD. + * Return a new RDD by applying a function to each partition of this RDD. */ - def mapPartitions(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = { - def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) + def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = { + def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: java.lang.Double) => x.doubleValue())) } /** - * Return a new RDD by applying a function to each partition of this RDD. + * Return a new RDD by applying a function to each partition of this RDD. */ - def mapPartitions[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]): + def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]): JavaPairRDD[K2, V2] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) - JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(f.keyType(), f.valueType()) + def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } - /** * Return a new RDD by applying a function to each partition of this RDD. */ - def mapPartitions(f: DoubleFlatMapFunction[java.util.Iterator[T]], - preservesPartitioning: Boolean): JavaDoubleRDD = { - def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) + def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]], + preservesPartitioning: Boolean): JavaDoubleRDD = { + def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) new JavaDoubleRDD(rdd.mapPartitions(fn, preservesPartitioning) - .map((x: java.lang.Double) => x.doubleValue())) + .map(x => x.doubleValue())) } /** * Return a new RDD by applying a function to each partition of this RDD. 
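The renamed methods above (mapToPair, mapToDouble, flatMapToPair, mapPartitionsToPair, mapPartitionsToDouble) replace overloads that previously shared the names map, flatMap, and mapPartitions. A small sketch of calling the renamed API, not taken from the patch; the object name and sample data are made up, and it uses the Java API from Scala purely for illustration.

import java.util.Arrays

import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.api.java.function.PairFunction

object JavaApiRenameSketch {
  def main(args: Array[String]) {
    val jsc = new JavaSparkContext("local", "rename-sketch")
    val words = jsc.parallelize(Arrays.asList("spark", "sasl", "yarn"))
    // Formerly an overload of map(); now spelled mapToPair() and invoked through call().
    val lengths = words.mapToPair(new PairFunction[String, String, java.lang.Integer] {
      def call(w: String): (String, java.lang.Integer) = (w, Integer.valueOf(w.length))
    })
    println(lengths.collect())
    jsc.stop()
  }
}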
*/ - def mapPartitions[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2], + def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2], preservesPartitioning: Boolean): JavaPairRDD[K2, V2] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) - JavaPairRDD.fromRDD(rdd.mapPartitions(fn, preservesPartitioning))(f.keyType(), f.valueType()) + def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + JavaPairRDD.fromRDD( + rdd.mapPartitions(fn, preservesPartitioning))(fakeClassTag[K2], fakeClassTag[V2]) } /** * Applies a function f to each partition of this RDD. */ def foreachPartition(f: VoidFunction[java.util.Iterator[T]]) { - rdd.foreachPartition((x => f(asJavaIterator(x)))) + rdd.foreachPartition((x => f.call(asJavaIterator(x)))) } /** @@ -205,7 +206,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def groupBy[K](f: JFunction[T, K]): JavaPairRDD[K, JList[T]] = { implicit val ctagK: ClassTag[K] = fakeClassTag implicit val ctagV: ClassTag[JList[T]] = fakeClassTag - JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(f.returnType))) + JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(fakeClassTag))) } /** @@ -215,7 +216,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def groupBy[K](f: JFunction[T, K], numPartitions: Int): JavaPairRDD[K, JList[T]] = { implicit val ctagK: ClassTag[K] = fakeClassTag implicit val ctagV: ClassTag[JList[T]] = fakeClassTag - JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(f.returnType))) + JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(fakeClassTag[K]))) } /** @@ -255,9 +256,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { other: JavaRDDLike[U, _], f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V]): JavaRDD[V] = { def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator( - f.apply(asJavaIterator(x), asJavaIterator(y)).iterator()) + f.call(asJavaIterator(x), asJavaIterator(y)).iterator()) JavaRDD.fromRDD( - rdd.zipPartitions(other.rdd)(fn)(other.classTag, f.elementType()))(f.elementType()) + rdd.zipPartitions(other.rdd)(fn)(other.classTag, fakeClassTag[V]))(fakeClassTag[V]) } // Actions (launch a job to return a value to the user program) @@ -266,7 +267,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Applies a function f to all elements of this RDD. */ def foreach(f: VoidFunction[T]) { - val cleanF = rdd.context.clean(f) + val cleanF = rdd.context.clean((x: T) => f.call(x)) rdd.foreach(cleanF) } @@ -281,7 +282,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Return an array that contains all of the elements in this RDD. + * @deprecated As of Spark 1.0.0, toArray() is deprecated, use {@link #collect()} instead */ + @Deprecated def toArray(): JList[T] = collect() /** @@ -320,7 +323,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def aggregate[U](zeroValue: U)(seqOp: JFunction2[U, T, U], combOp: JFunction2[U, U, U]): U = - rdd.aggregate(zeroValue)(seqOp, combOp)(seqOp.returnType) + rdd.aggregate(zeroValue)(seqOp, combOp)(fakeClassTag[U]) /** * Return the number of elements in the RDD. @@ -475,6 +478,26 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { new java.util.ArrayList(arr) } + /** + * Returns the maximum element from this RDD as defined by the specified + * Comparator[T]. 
+ * @params comp the comparator that defines ordering + * @return the maximum of the RDD + * */ + def max(comp: Comparator[T]): T = { + rdd.max()(Ordering.comparatorToOrdering(comp)) + } + + /** + * Returns the minimum element from this RDD as defined by the specified + * Comparator[T]. + * @params comp the comparator that defines ordering + * @return the minimum of the RDD + * */ + def min(comp: Comparator[T]): T = { + rdd.min()(Ordering.comparatorToOrdering(comp)) + } + /** * Returns the first K elements from this RDD using the * natural ordering for T while maintain the order. @@ -498,8 +521,4 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def name(): String = rdd.name - /** Reset generator */ - def setGenerator(_generator: String) = { - rdd.setGenerator(_generator) - } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index dc26b7f621fee..8e0eab56a3dcf 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -17,6 +17,7 @@ package org.apache.spark.api.java +import java.util import java.util.{Map => JMap} import scala.collection.JavaConversions @@ -92,6 +93,24 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork private[spark] val env = sc.env + def isLocal: java.lang.Boolean = sc.isLocal + + def sparkUser: String = sc.sparkUser + + def master: String = sc.master + + def appName: String = sc.appName + + def jars: util.List[String] = sc.jars + + def startTime: java.lang.Long = sc.startTime + + /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */ + def defaultParallelism: java.lang.Integer = sc.defaultParallelism + + /** Default min number of partitions for Hadoop RDDs when not given by user */ + def defaultMinSplits: java.lang.Integer = sc.defaultMinSplits + /** Distribute a local Scala collection to form an RDD. */ def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = { implicit val ctag: ClassTag[T] = fakeClassTag diff --git a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.scala deleted file mode 100644 index 2cdf2e92c3daa..0000000000000 --- a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.scala +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.java.function - -import java.lang.{Double => JDouble} - -/** - * A function that returns Doubles, and can be used to construct DoubleRDDs. 
- */ -// DoubleFunction does not extend Function because some UDF functions, like map, -// are overloaded for both Function and DoubleFunction. -abstract class DoubleFunction[T] extends WrappedFunction1[T, JDouble] with Serializable { - // Intentionally left blank -} diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.scala deleted file mode 100644 index 8467bbb892ab0..0000000000000 --- a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.java.function - -import java.lang.{Iterable => JIterable} -import org.apache.spark.api.java.JavaSparkContext -import scala.reflect.ClassTag - -/** - * A function that returns zero or more key-value pair records from each input record. The - * key-value pairs are represented as scala.Tuple2 objects. - */ -// PairFlatMapFunction does not extend FlatMapFunction because flatMap is -// overloaded for both FlatMapFunction and PairFlatMapFunction. -abstract class PairFlatMapFunction[T, K, V] extends WrappedFunction1[T, JIterable[(K, V)]] - with Serializable { - - def keyType(): ClassTag[K] = JavaSparkContext.fakeClassTag - - def valueType(): ClassTag[V] = JavaSparkContext.fakeClassTag -} diff --git a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction1.scala b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction1.scala deleted file mode 100644 index cfe694f65d558..0000000000000 --- a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction1.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.java.function - -import scala.runtime.AbstractFunction1 - -/** - * Subclass of Function1 for ease of calling from Java. 
The main thing it does is re-expose the - * apply() method as call() and declare that it can throw Exception (since AbstractFunction1.apply - * isn't marked to allow that). - */ -private[spark] abstract class WrappedFunction1[T, R] extends AbstractFunction1[T, R] { - @throws(classOf[Exception]) - def call(t: T): R - - final def apply(t: T): R = call(t) -} diff --git a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction2.scala b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction2.scala deleted file mode 100644 index eb9277c6fb4cb..0000000000000 --- a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction2.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.java.function - -import scala.runtime.AbstractFunction2 - -/** - * Subclass of Function2 for ease of calling from Java. The main thing it does is re-expose the - * apply() method as call() and declare that it can throw Exception (since AbstractFunction2.apply - * isn't marked to allow that). - */ -private[spark] abstract class WrappedFunction2[T1, T2, R] extends AbstractFunction2[T1, T2, R] { - @throws(classOf[Exception]) - def call(t1: T1, t2: T2): R - - final def apply(t1: T1, t2: T2): R = call(t1, t2) -} diff --git a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala deleted file mode 100644 index d314dbdf1d980..0000000000000 --- a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.java.function - -import scala.runtime.AbstractFunction3 - -/** - * Subclass of Function3 for ease of calling from Java. The main thing it does is re-expose the - * apply() method as call() and declare that it can throw Exception (since AbstractFunction3.apply - * isn't marked to allow that). 
- */ -private[spark] abstract class WrappedFunction3[T1, T2, T3, R] - extends AbstractFunction3[T1, T2, T3, R] { - @throws(classOf[Exception]) - def call(t1: T1, t2: T2, t3: T3): R - - final def apply(t1: T1, t2: T2, t3: T3): R = call(t1, t2, t3) -} - diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index ca0e702addddd..43631e0e3bd49 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -103,6 +103,14 @@ private[spark] class PythonRDD[T: ClassTag]( } }.start() + /* + * Partial fix for SPARK-1019: Attempts to stop reading the input stream since + * other completion callbacks might invalidate the input. Because interruption + * is not synchronous this still leaves a potential race where the interruption is + * processed only after the stream becomes invalid. + */ + context.addOnCompleteCallback(() => context.interrupted = true) + // Return an iterator that read lines from the process's stdout val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) val stdoutIterator = new Iterator[Array[Byte]] { diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index d113d4040594d..e3c3a12d16f2a 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -60,7 +60,8 @@ abstract class Broadcast[T](val id: Long) extends Serializable { } private[spark] -class BroadcastManager(val _isDriver: Boolean, conf: SparkConf) extends Logging with Serializable { +class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager) + extends Logging with Serializable { private var initialized = false private var broadcastFactory: BroadcastFactory = null @@ -78,7 +79,7 @@ class BroadcastManager(val _isDriver: Boolean, conf: SparkConf) extends Logging Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] // Initialize appropriate BroadcastFactory and BroadcastObject - broadcastFactory.initialize(isDriver, conf) + broadcastFactory.initialize(isDriver, conf, securityManager) initialized = true } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index 940e5ab805100..6beecaeced5be 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.broadcast +import org.apache.spark.SecurityManager import org.apache.spark.SparkConf @@ -26,7 +27,7 @@ import org.apache.spark.SparkConf * entire Spark job. 
*/ trait BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf): Unit + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 20207c261320b..e8eb04bb10469 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -18,13 +18,13 @@ package org.apache.spark.broadcast import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream} -import java.net.URL +import java.net.{URL, URLConnection, URI} import java.util.concurrent.TimeUnit import it.unimi.dsi.fastutil.io.FastBufferedInputStream import it.unimi.dsi.fastutil.io.FastBufferedOutputStream -import org.apache.spark.{HttpServer, Logging, SparkConf, SparkEnv} +import org.apache.spark.{SparkConf, HttpServer, Logging, SecurityManager, SparkEnv} import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils} @@ -67,7 +67,9 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea * A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium. */ class HttpBroadcastFactory extends BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf) { HttpBroadcast.initialize(isDriver, conf) } + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { + HttpBroadcast.initialize(isDriver, conf, securityMgr) + } def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = new HttpBroadcast[T](value_, isLocal, id) @@ -83,6 +85,7 @@ private object HttpBroadcast extends Logging { private var bufferSize: Int = 65536 private var serverUri: String = null private var server: HttpServer = null + private var securityManager: SecurityManager = null // TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist private val files = new TimeStampedHashSet[String] @@ -92,11 +95,12 @@ private object HttpBroadcast extends Logging { private var compressionCodec: CompressionCodec = null - def initialize(isDriver: Boolean, conf: SparkConf) { + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { synchronized { if (!initialized) { bufferSize = conf.getInt("spark.buffer.size", 65536) compress = conf.getBoolean("spark.broadcast.compress", true) + securityManager = securityMgr if (isDriver) { createServer(conf) conf.set("spark.httpBroadcast.uri", serverUri) @@ -126,7 +130,7 @@ private object HttpBroadcast extends Logging { private def createServer(conf: SparkConf) { broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf)) - server = new HttpServer(broadcastDir) + server = new HttpServer(broadcastDir, securityManager) server.start() serverUri = server.uri logInfo("Broadcast server started at " + serverUri) @@ -149,11 +153,23 @@ private object HttpBroadcast extends Logging { } def read[T](id: Long): T = { + logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id) val url = serverUri + "/" + BroadcastBlockId(id).name + + var uc: URLConnection = null + if (securityManager.isAuthenticationEnabled()) { + logDebug("broadcast security enabled") + val newuri = Utils.constructURIForAuthentication(new 
URI(url), securityManager) + uc = newuri.toURL().openConnection() + uc.setAllowUserInteraction(false) + } else { + logDebug("broadcast not using security") + uc = new URL(url).openConnection() + } + val in = { - val httpConnection = new URL(url).openConnection() - httpConnection.setReadTimeout(httpReadTimeout) - val inputStream = httpConnection.getInputStream + uc.setReadTimeout(httpReadTimeout) + val inputStream = uc.getInputStream(); if (compress) { compressionCodec.compressedInputStream(inputStream) } else { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 22d783c8590c6..3cd71213769b7 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -241,7 +241,9 @@ private[spark] case class TorrentInfo( */ class TorrentBroadcastFactory extends BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.initialize(isDriver, conf) } + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { + TorrentBroadcast.initialize(isDriver, conf) + } def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = new TorrentBroadcast[T](value_, isLocal, id) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index eb5676b51d836..d9e3035e1ab59 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -26,7 +26,7 @@ import akka.pattern.ask import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent} import org.apache.log4j.{Level, Logger} -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.util.{AkkaUtils, Utils} @@ -141,7 +141,7 @@ object Client { // TODO: See if we can initialize akka so return messages are sent back using the same TCP // flow. Else, this (sadly) requires the DriverClient be routable from the Master. val (actorSystem, _) = AkkaUtils.createActorSystem( - "driverClient", Utils.localHostName(), 0, false, conf) + "driverClient", Utils.localHostName(), 0, false, conf, new SecurityManager(conf)) actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf)) diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index 190b331cfe7d8..f4eb1601be3e4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -27,22 +27,27 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.sys.process._ -import net.liftweb.json.JsonParser +import org.json4s._ +import org.json4s.jackson.JsonMethods -import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.deploy.master.RecoveryState +import org.apache.spark.{Logging, SparkConf, SparkContext} +import org.apache.spark.deploy.master.{RecoveryState, SparkCuratorUtil} /** * This suite tests the fault tolerance of the Spark standalone scheduler, mainly the Master. * In order to mimic a real distributed cluster more closely, Docker is used. 
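    // Editor's note (not part of the patch): the HttpBroadcast.read change above now
    // opens an authenticated URLConnection when the SecurityManager reports that
    // authentication is enabled, and a plain one otherwise. A minimal standalone
    // sketch of that pattern follows; it uses only java.net, the credential values
    // are hypothetical, and it only approximates the role Utils.constructURIForAuthentication
    // plays in the patch (embedding user info into the URI).
    import java.net.{URI, URL, URLConnection}

    object AuthenticatedFetchSketch {
      def openConnection(url: String, authEnabled: Boolean,
                         user: String, secret: String): URLConnection = {
        if (authEnabled) {
          // Rebuild the URI with "user:secret" as its userInfo component.
          val base = new URI(url)
          val withAuth = new URI(base.getScheme, s"$user:$secret", base.getHost,
            base.getPort, base.getPath, base.getQuery, base.getFragment)
          val conn = withAuth.toURL.openConnection()
          conn.setAllowUserInteraction(false)
          conn
        } else {
          new URL(url).openConnection()
        }
      }
    }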
* Execute using - * ./spark-class org.apache.spark.deploy.FaultToleranceTest + * ./bin/spark-class org.apache.spark.deploy.FaultToleranceTest * - * Make sure that that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS: + * Make sure that that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS + * *and* SPARK_JAVA_OPTS: * - spark.deploy.recoveryMode=ZOOKEEPER * - spark.deploy.zookeeper.url=172.17.42.1:2181 * Note that 172.17.42.1 is the default docker ip for the host and 2181 is the default ZK port. * + * In case of failure, make sure to kill off prior docker containers before restarting: + * docker kill $(docker ps -q) + * * Unfortunately, due to the Docker dependency this suite cannot be run automatically without a * working installation of Docker. In addition to having Docker, the following are assumed: * - Docker can run without sudo (see http://docs.docker.io/en/latest/use/basics/) @@ -50,10 +55,16 @@ import org.apache.spark.deploy.master.RecoveryState * docker/ directory. Run 'docker/spark-test/build' to generate these. */ private[spark] object FaultToleranceTest extends App with Logging { + + val conf = new SparkConf() + val ZK_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + val masters = ListBuffer[TestMasterInfo]() val workers = ListBuffer[TestWorkerInfo]() var sc: SparkContext = _ + val zk = SparkCuratorUtil.newClient(conf) + var numPassed = 0 var numFailed = 0 @@ -71,6 +82,10 @@ private[spark] object FaultToleranceTest extends App with Logging { sc = null } terminateCluster() + + // Clear ZK directories in between tests (for speed purposes) + SparkCuratorUtil.deleteRecursive(zk, ZK_DIR + "/spark_leader") + SparkCuratorUtil.deleteRecursive(zk, ZK_DIR + "/master_status") } test("sanity-basic") { @@ -167,26 +182,34 @@ private[spark] object FaultToleranceTest extends App with Logging { try { fn numPassed += 1 + logInfo("==============================================") logInfo("Passed: " + name) + logInfo("==============================================") } catch { case e: Exception => numFailed += 1 + logInfo("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") logError("FAILED: " + name, e) + logInfo("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + sys.exit(1) } afterEach() } def addMasters(num: Int) { + logInfo(s">>>>> ADD MASTERS $num <<<<<") (1 to num).foreach { _ => masters += SparkDocker.startMaster(dockerMountDir) } } def addWorkers(num: Int) { + logInfo(s">>>>> ADD WORKERS $num <<<<<") val masterUrls = getMasterUrls(masters) (1 to num).foreach { _ => workers += SparkDocker.startWorker(dockerMountDir, masterUrls) } } /** Creates a SparkContext, which constructs a Client to interact with our cluster. */ def createClient() = { + logInfo(">>>>> CREATE CLIENT <<<<<") if (sc != null) { sc.stop() } // Counter-hack: Because of a hack in SparkEnv#create() that changes this // property, we need to reset it. 
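FaultToleranceTest and TestMasterInfo in this file switch from lift-json to json4s for reading the master's /json endpoint. A small standalone sketch of that json4s idiom, assuming a hypothetical payload shaped loosely like the endpoint's output:

    import org.json4s._
    import org.json4s.jackson.JsonMethods

    object MasterJsonSketch {
      implicit val formats: Formats = DefaultFormats

      // Parse a JSON document and pull out the web UI addresses of ALIVE workers,
      // using the same \ / children / extract idiom as TestMasterInfo.readState().
      def liveWorkerAddresses(jsonText: String): List[String] = {
        val json = JsonMethods.parse(jsonText)
        (json \ "workers").children
          .filter(w => (w \ "state").extract[String] == "ALIVE")
          .map(w => (w \ "webuiaddress").extract[String])
      }

      def main(args: Array[String]): Unit = {
        // Hypothetical sample; the real endpoint carries more fields.
        val sample =
          """{"workers": [
            |  {"state": "ALIVE", "webuiaddress": "http://10.0.0.1:8081"},
            |  {"state": "DEAD",  "webuiaddress": "http://10.0.0.2:8081"}
            |]}""".stripMargin
        println(liveWorkerAddresses(sample))   // List(http://10.0.0.1:8081)
      }
    }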
@@ -205,6 +228,7 @@ private[spark] object FaultToleranceTest extends App with Logging { } def killLeader(): Unit = { + logInfo(">>>>> KILL LEADER <<<<<") masters.foreach(_.readState()) val leader = getLeader masters -= leader @@ -214,6 +238,7 @@ private[spark] object FaultToleranceTest extends App with Logging { def delay(secs: Duration = 5.seconds) = Thread.sleep(secs.toMillis) def terminateCluster() { + logInfo(">>>>> TERMINATE CLUSTER <<<<<") masters.foreach(_.kill()) workers.foreach(_.kill()) masters.clear() @@ -244,6 +269,7 @@ private[spark] object FaultToleranceTest extends App with Logging { * are all alive in a proper configuration (e.g., only one leader). */ def assertValidClusterState() = { + logInfo(">>>>> ASSERT VALID CLUSTER STATE <<<<<") assertUsable() var numAlive = 0 var numStandby = 0 @@ -311,7 +337,7 @@ private[spark] object FaultToleranceTest extends App with Logging { private[spark] class TestMasterInfo(val ip: String, val dockerId: DockerId, val logFile: File) extends Logging { - implicit val formats = net.liftweb.json.DefaultFormats + implicit val formats = org.json4s.DefaultFormats var state: RecoveryState.Value = _ var liveWorkerIPs: List[String] = _ var numLiveApps = 0 @@ -321,11 +347,15 @@ private[spark] class TestMasterInfo(val ip: String, val dockerId: DockerId, val def readState() { try { val masterStream = new InputStreamReader(new URL("http://%s:8080/json".format(ip)).openStream) - val json = JsonParser.parse(masterStream, closeAutomatically = true) + val json = JsonMethods.parse(masterStream) val workers = json \ "workers" val liveWorkers = workers.children.filter(w => (w \ "state").extract[String] == "ALIVE") - liveWorkerIPs = liveWorkers.map(w => (w \ "host").extract[String]) + // Extract the worker IP from "webuiaddress" (rather than "host") because the host name + // on containers is a weird hash instead of the actual IP address. 
+ liveWorkerIPs = liveWorkers.map { + w => (w \ "webuiaddress").extract[String].stripPrefix("http://").stripSuffix(":8081") + } numLiveApps = (json \ "activeapps").children.size @@ -349,7 +379,7 @@ private[spark] class TestMasterInfo(val ip: String, val dockerId: DockerId, val private[spark] class TestWorkerInfo(val ip: String, val dockerId: DockerId, val logFile: File) extends Logging { - implicit val formats = net.liftweb.json.DefaultFormats + implicit val formats = org.json4s.DefaultFormats logDebug("Created worker: " + this) @@ -402,7 +432,7 @@ private[spark] object Docker extends Logging { def makeRunCmd(imageTag: String, args: String = "", mountDir: String = ""): ProcessBuilder = { val mountCmd = if (mountDir != "") { " -v " + mountDir } else "" - val cmd = "docker run %s %s %s".format(mountCmd, imageTag, args) + val cmd = "docker run -privileged %s %s %s".format(mountCmd, imageTag, args) logDebug("Run command: " + cmd) cmd } diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index 318beb5db5214..cefb1ff97e83c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy -import net.liftweb.json.JsonDSL._ +import org.json4s.JsonDSL._ import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index b479225b45ee9..d2d8d6d662d55 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -21,10 +21,13 @@ import java.security.PrivilegedExceptionAction import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation import org.apache.spark.{SparkContext, SparkException} +import scala.collection.JavaConversions._ + /** * Contains util methods to interact with Hadoop from Spark. */ @@ -33,15 +36,9 @@ class SparkHadoopUtil { UserGroupInformation.setConfiguration(conf) def runAsUser(user: String)(func: () => Unit) { - // if we are already running as the user intended there is no reason to do the doAs. It - // will actually break secure HDFS access as it doesn't fill in the credentials. Also if - // the user is UNKNOWN then we shouldn't be creating a remote unknown user - // (this is actually the path spark on yarn takes) since SPARK_USER is initialized only - // in SparkContext. - val currentUser = Option(System.getProperty("user.name")). - getOrElse(SparkContext.SPARK_UNKNOWN_USER) - if (user != SparkContext.SPARK_UNKNOWN_USER && currentUser != user) { + if (user != SparkContext.SPARK_UNKNOWN_USER) { val ugi = UserGroupInformation.createRemoteUser(user) + transferCredentials(UserGroupInformation.getCurrentUser(), ugi) ugi.doAs(new PrivilegedExceptionAction[Unit] { def run: Unit = func() }) @@ -50,6 +47,12 @@ class SparkHadoopUtil { } } + def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) { + for (token <- source.getTokens()) { + dest.addToken(token) + } + } + /** * Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop * subsystems. 
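The runAsUser change above stops skipping the doAs for matching users and instead copies the caller's Hadoop tokens onto the remote UGI, so work run as the target user still carries the credentials needed for secure HDFS. A minimal sketch of that pattern, assuming Hadoop client classes on the classpath (the object and method names are illustrative, not Spark's):

    import java.security.PrivilegedExceptionAction

    import scala.collection.JavaConversions._

    import org.apache.hadoop.security.UserGroupInformation

    object RunAsUserSketch {
      def runAs(user: String)(work: () => Unit): Unit = {
        val ugi = UserGroupInformation.createRemoteUser(user)
        // Copy the current user's delegation tokens; without this the remote UGI
        // has no credentials and secure HDFS calls inside doAs would fail.
        for (token <- UserGroupInformation.getCurrentUser.getTokens) {
          ugi.addToken(token)
        }
        ugi.doAs(new PrivilegedExceptionAction[Unit] {
          override def run(): Unit = work()
        })
      }
    }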
@@ -63,6 +66,15 @@ class SparkHadoopUtil { def addCredentials(conf: JobConf) {} def isYarnMode(): Boolean = { false } + + def getCurrentUserCredentials(): Credentials = { null } + + def addCurrentUserCredentials(creds: Credentials) {} + + def addSecretKeyToUserCredentials(key: String, secret: String) {} + + def getSecretKeyFromUserCredentials(key: String): Array[Byte] = { null } + } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index 1550c3eb4286b..63f166d401059 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.client -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.util.{AkkaUtils, Utils} @@ -45,8 +45,9 @@ private[spark] object TestClient { def main(args: Array[String]) { val url = args(0) + val conf = new SparkConf val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0, - conf = new SparkConf) + conf = conf, securityManager = new SecurityManager(conf)) val desc = new ApplicationDescription( "TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), Some("dummy-spark-home"), "ignored") diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala index f25a1ad3bf92a..a730fe1f599af 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala @@ -30,6 +30,7 @@ import org.apache.spark.deploy.master.MasterMessages.ElectedLeader * [[org.apache.spark.deploy.master.MasterMessages.RevokedLeadership RevokedLeadership]] */ private[spark] trait LeaderElectionAgent extends Actor { + //TODO: LeaderElectionAgent does not necessary to be an Actor anymore, need refactoring. 
val masterActor: ActorRef } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 51794ce40cb45..b8dfa44102583 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -30,7 +30,7 @@ import akka.pattern.ask import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import akka.serialization.SerializationExtension -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.DriverState.DriverState @@ -39,7 +39,8 @@ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.{AkkaUtils, Utils} -private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging { +private[spark] class Master(host: String, port: Int, webUiPort: Int, + val securityMgr: SecurityManager) extends Actor with Logging { import context.dispatcher // to use Akka's scheduler.schedule() val conf = new SparkConf @@ -70,8 +71,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act Utils.checkHost(host, "Expected hostname") - val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf) - val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf) + val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr) + val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf, + securityMgr) val masterSource = new MasterSource(this) val webUi = new MasterWebUI(this, webUiPort) @@ -529,8 +531,15 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act val workerAddress = worker.actor.path.address if (addressToWorker.contains(workerAddress)) { - logInfo("Attempted to re-register worker at same address: " + workerAddress) - return false + val oldWorker = addressToWorker(workerAddress) + if (oldWorker.state == WorkerState.UNKNOWN) { + // A worker registering from UNKNOWN implies that the worker was restarted during recovery. + // The old worker must thus be dead, so we will remove it and accept the new worker. 
+ removeWorker(oldWorker) + } else { + logInfo("Attempted to re-register worker at same address: " + workerAddress) + return false + } } workers += worker @@ -711,8 +720,11 @@ private[spark] object Master { def startSystemAndActor(host: String, port: Int, webUiPort: Int, conf: SparkConf) : (ActorSystem, Int, Int) = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf) - val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort), actorName) + val securityMgr = new SecurityManager(conf) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf, + securityManager = securityMgr) + val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort, + securityMgr), actorName) val timeout = AkkaUtils.askTimeout(conf) val respFuture = actor.ask(RequestWebUIPort)(timeout) val resp = Await.result(respFuture, timeout).asInstanceOf[WebUIPortResponse] diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala index 74a9f8cd824fb..db72d8ae9bdaf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala @@ -28,10 +28,6 @@ private[master] object MasterMessages { case object RevokedLeadership - // Actor System to LeaderElectionAgent - - case object CheckLeader - // Actor System to Master case object CheckForWorkerTimeOut diff --git a/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala b/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala new file mode 100644 index 0000000000000..4781a80d470e1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.master + +import scala.collection.JavaConversions._ + +import org.apache.curator.framework.{CuratorFramework, CuratorFrameworkFactory} +import org.apache.curator.retry.ExponentialBackoffRetry +import org.apache.zookeeper.KeeperException + +import org.apache.spark.{Logging, SparkConf} + +object SparkCuratorUtil extends Logging { + + val ZK_CONNECTION_TIMEOUT_MILLIS = 15000 + val ZK_SESSION_TIMEOUT_MILLIS = 60000 + val RETRY_WAIT_MILLIS = 5000 + val MAX_RECONNECT_ATTEMPTS = 3 + + def newClient(conf: SparkConf): CuratorFramework = { + val ZK_URL = conf.get("spark.deploy.zookeeper.url") + val zk = CuratorFrameworkFactory.newClient(ZK_URL, + ZK_SESSION_TIMEOUT_MILLIS, ZK_CONNECTION_TIMEOUT_MILLIS, + new ExponentialBackoffRetry(RETRY_WAIT_MILLIS, MAX_RECONNECT_ATTEMPTS)) + zk.start() + zk + } + + def mkdir(zk: CuratorFramework, path: String) { + if (zk.checkExists().forPath(path) == null) { + try { + zk.create().creatingParentsIfNeeded().forPath(path) + } catch { + case nodeExist: KeeperException.NodeExistsException => + // do nothing, ignore node existing exception. + case e: Exception => throw e + } + } + } + + def deleteRecursive(zk: CuratorFramework, path: String) { + if (zk.checkExists().forPath(path) != null) { + for (child <- zk.getChildren.forPath(path)) { + zk.delete().forPath(path + "/" + child) + } + zk.delete().forPath(path) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala b/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala deleted file mode 100644 index 57758055b19c0..0000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala +++ /dev/null @@ -1,205 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.master - -import scala.collection.JavaConversions._ - -import org.apache.zookeeper._ -import org.apache.zookeeper.Watcher.Event.KeeperState -import org.apache.zookeeper.data.Stat - -import org.apache.spark.{Logging, SparkConf} - -/** - * Provides a Scala-side interface to the standard ZooKeeper client, with the addition of retry - * logic. If the ZooKeeper session expires or otherwise dies, a new ZooKeeper session will be - * created. If ZooKeeper remains down after several retries, the given - * [[org.apache.spark.deploy.master.SparkZooKeeperWatcher SparkZooKeeperWatcher]] will be - * informed via zkDown(). - * - * Additionally, all commands sent to ZooKeeper will be retried until they either fail too many - * times or a semantic exception is thrown (e.g., "node already exists"). 
- */ -private[spark] class SparkZooKeeperSession(zkWatcher: SparkZooKeeperWatcher, - conf: SparkConf) extends Logging { - val ZK_URL = conf.get("spark.deploy.zookeeper.url", "") - - val ZK_ACL = ZooDefs.Ids.OPEN_ACL_UNSAFE - val ZK_TIMEOUT_MILLIS = 30000 - val RETRY_WAIT_MILLIS = 5000 - val ZK_CHECK_PERIOD_MILLIS = 10000 - val MAX_RECONNECT_ATTEMPTS = 3 - - private var zk: ZooKeeper = _ - - private val watcher = new ZooKeeperWatcher() - private var reconnectAttempts = 0 - private var closed = false - - /** Connect to ZooKeeper to start the session. Must be called before anything else. */ - def connect() { - connectToZooKeeper() - - new Thread() { - override def run() = sessionMonitorThread() - }.start() - } - - def sessionMonitorThread(): Unit = { - while (!closed) { - Thread.sleep(ZK_CHECK_PERIOD_MILLIS) - if (zk.getState != ZooKeeper.States.CONNECTED) { - reconnectAttempts += 1 - val attemptsLeft = MAX_RECONNECT_ATTEMPTS - reconnectAttempts - if (attemptsLeft <= 0) { - logError("Could not connect to ZooKeeper: system failure") - zkWatcher.zkDown() - close() - } else { - logWarning("ZooKeeper connection failed, retrying " + attemptsLeft + " more times...") - connectToZooKeeper() - } - } - } - } - - def close() { - if (!closed && zk != null) { zk.close() } - closed = true - } - - private def connectToZooKeeper() { - if (zk != null) zk.close() - zk = new ZooKeeper(ZK_URL, ZK_TIMEOUT_MILLIS, watcher) - } - - /** - * Attempts to maintain a live ZooKeeper exception despite (very) transient failures. - * Mainly useful for handling the natural ZooKeeper session expiration. - */ - private class ZooKeeperWatcher extends Watcher { - def process(event: WatchedEvent) { - if (closed) { return } - - event.getState match { - case KeeperState.SyncConnected => - reconnectAttempts = 0 - zkWatcher.zkSessionCreated() - case KeeperState.Expired => - connectToZooKeeper() - case KeeperState.Disconnected => - logWarning("ZooKeeper disconnected, will retry...") - case s => // Do nothing - } - } - } - - def create(path: String, bytes: Array[Byte], createMode: CreateMode): String = { - retry { - zk.create(path, bytes, ZK_ACL, createMode) - } - } - - def exists(path: String, watcher: Watcher = null): Stat = { - retry { - zk.exists(path, watcher) - } - } - - def getChildren(path: String, watcher: Watcher = null): List[String] = { - retry { - zk.getChildren(path, watcher).toList - } - } - - def getData(path: String): Array[Byte] = { - retry { - zk.getData(path, false, null) - } - } - - def delete(path: String, version: Int = -1): Unit = { - retry { - zk.delete(path, version) - } - } - - /** - * Creates the given directory (non-recursively) if it doesn't exist. - * All znodes are created in PERSISTENT mode with no data. - */ - def mkdir(path: String) { - if (exists(path) == null) { - try { - create(path, "".getBytes, CreateMode.PERSISTENT) - } catch { - case e: Exception => - // If the exception caused the directory not to be created, bubble it up, - // otherwise ignore it. - if (exists(path) == null) { throw e } - } - } - } - - /** - * Recursively creates all directories up to the given one. - * All znodes are created in PERSISTENT mode with no data. - */ - def mkdirRecursive(path: String) { - var fullDir = "" - for (dentry <- path.split("/").tail) { - fullDir += "/" + dentry - mkdir(fullDir) - } - } - - /** - * Retries the given function up to 3 times. 
The assumption is that failure is transient, - * UNLESS it is a semantic exception (i.e., trying to get data from a node that doesn't exist), - * in which case the exception will be thrown without retries. - * - * @param fn Block to execute, possibly multiple times. - */ - def retry[T](fn: => T, n: Int = MAX_RECONNECT_ATTEMPTS): T = { - try { - fn - } catch { - case e: KeeperException.NoNodeException => throw e - case e: KeeperException.NodeExistsException => throw e - case e: Exception if n > 0 => - logError("ZooKeeper exception, " + n + " more retries...", e) - Thread.sleep(RETRY_WAIT_MILLIS) - retry(fn, n-1) - } - } -} - -trait SparkZooKeeperWatcher { - /** - * Called whenever a ZK session is created -- - * this will occur when we create our first session as well as each time - * the session expires or errors out. - */ - def zkSessionCreated() - - /** - * Called if ZK appears to be completely down (i.e., not just a transient error). - * We will no longer attempt to reconnect to ZK, and the SparkZooKeeperSession is considered dead. - */ - def zkDown() -} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala index 47b8f67f8a45b..285f9b014e291 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -18,105 +18,67 @@ package org.apache.spark.deploy.master import akka.actor.ActorRef -import org.apache.zookeeper._ -import org.apache.zookeeper.Watcher.Event.EventType import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.master.MasterMessages._ +import org.apache.curator.framework.CuratorFramework +import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch} private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef, masterUrl: String, conf: SparkConf) - extends LeaderElectionAgent with SparkZooKeeperWatcher with Logging { + extends LeaderElectionAgent with LeaderLatchListener with Logging { val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/leader_election" - private val watcher = new ZooKeeperWatcher() - private val zk = new SparkZooKeeperSession(this, conf) + private var zk: CuratorFramework = _ + private var leaderLatch: LeaderLatch = _ private var status = LeadershipStatus.NOT_LEADER - private var myLeaderFile: String = _ - private var leaderUrl: String = _ override def preStart() { + logInfo("Starting ZooKeeper LeaderElection agent") - zk.connect() - } + zk = SparkCuratorUtil.newClient(conf) + leaderLatch = new LeaderLatch(zk, WORKING_DIR) + leaderLatch.addListener(this) - override def zkSessionCreated() { - synchronized { - zk.mkdirRecursive(WORKING_DIR) - myLeaderFile = - zk.create(WORKING_DIR + "/master_", masterUrl.getBytes, CreateMode.EPHEMERAL_SEQUENTIAL) - self ! CheckLeader - } + leaderLatch.start() } override def preRestart(reason: scala.Throwable, message: scala.Option[scala.Any]) { - logError("LeaderElectionAgent failed, waiting " + zk.ZK_TIMEOUT_MILLIS + "...", reason) - Thread.sleep(zk.ZK_TIMEOUT_MILLIS) + logError("LeaderElectionAgent failed...", reason) super.preRestart(reason, message) } - override def zkDown() { - logError("ZooKeeper down! 
LeaderElectionAgent shutting down Master.") - System.exit(1) - } - override def postStop() { + leaderLatch.close() zk.close() } override def receive = { - case CheckLeader => checkLeader() + case _ => } - private class ZooKeeperWatcher extends Watcher { - def process(event: WatchedEvent) { - if (event.getType == EventType.NodeDeleted) { - logInfo("Leader file disappeared, a master is down!") - self ! CheckLeader + override def isLeader() { + synchronized { + // could have lost leadership by now. + if (!leaderLatch.hasLeadership) { + return } - } - } - /** Uses ZK leader election. Navigates several ZK potholes along the way. */ - def checkLeader() { - val masters = zk.getChildren(WORKING_DIR).toList - val leader = masters.sorted.head - val leaderFile = WORKING_DIR + "/" + leader - - // Setup a watch for the current leader. - zk.exists(leaderFile, watcher) - - try { - leaderUrl = new String(zk.getData(leaderFile)) - } catch { - // A NoNodeException may be thrown if old leader died since the start of this method call. - // This is fine -- just check again, since we're guaranteed to see the new values. - case e: KeeperException.NoNodeException => - logInfo("Leader disappeared while reading it -- finding next leader") - checkLeader() - return + logInfo("We have gained leadership") + updateLeadershipStatus(true) } + } - // Synchronization used to ensure no interleaving between the creation of a new session and the - // checking of a leader, which could cause us to delete our real leader file erroneously. + override def notLeader() { synchronized { - val isLeader = myLeaderFile == leaderFile - if (!isLeader && leaderUrl == masterUrl) { - // We found a different master file pointing to this process. - // This can happen in the following two cases: - // (1) The master process was restarted on the same node. - // (2) The ZK server died between creating the file and returning the name of the file. - // For this case, we will end up creating a second file, and MUST explicitly delete the - // first one, since our ZK session is still open. - // Note that this deletion will cause a NodeDeleted event to be fired so we check again for - // leader changes. - assert(leaderFile < myLeaderFile) - logWarning("Cleaning up old ZK master election file that points to this master.") - zk.delete(leaderFile) - } else { - updateLeadershipStatus(isLeader) + // could have gained leadership by now. 
+ if (leaderLatch.hasLeadership) { + return } + + logInfo("We have lost leadership") + updateLeadershipStatus(false) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 48b2fc06a9d70..5413ff671ad8d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -17,36 +17,28 @@ package org.apache.spark.deploy.master +import scala.collection.JavaConversions._ + import akka.serialization.Serialization -import org.apache.zookeeper._ +import org.apache.zookeeper.CreateMode import org.apache.spark.{Logging, SparkConf} class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf) extends PersistenceEngine - with SparkZooKeeperWatcher with Logging { val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/master_status" + val zk = SparkCuratorUtil.newClient(conf) - val zk = new SparkZooKeeperSession(this, conf) - - zk.connect() - - override def zkSessionCreated() { - zk.mkdirRecursive(WORKING_DIR) - } - - override def zkDown() { - logError("PersistenceEngine disconnected from ZooKeeper -- ZK looks down.") - } + SparkCuratorUtil.mkdir(zk, WORKING_DIR) override def addApplication(app: ApplicationInfo) { serializeIntoFile(WORKING_DIR + "/app_" + app.id, app) } override def removeApplication(app: ApplicationInfo) { - zk.delete(WORKING_DIR + "/app_" + app.id) + zk.delete().forPath(WORKING_DIR + "/app_" + app.id) } override def addDriver(driver: DriverInfo) { @@ -54,7 +46,7 @@ class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf) } override def removeDriver(driver: DriverInfo) { - zk.delete(WORKING_DIR + "/driver_" + driver.id) + zk.delete().forPath(WORKING_DIR + "/driver_" + driver.id) } override def addWorker(worker: WorkerInfo) { @@ -62,7 +54,7 @@ class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf) } override def removeWorker(worker: WorkerInfo) { - zk.delete(WORKING_DIR + "/worker_" + worker.id) + zk.delete().forPath(WORKING_DIR + "/worker_" + worker.id) } override def close() { @@ -70,26 +62,34 @@ class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf) } override def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { - val sortedFiles = zk.getChildren(WORKING_DIR).toList.sorted + val sortedFiles = zk.getChildren().forPath(WORKING_DIR).toList.sorted val appFiles = sortedFiles.filter(_.startsWith("app_")) - val apps = appFiles.map(deserializeFromFile[ApplicationInfo]) + val apps = appFiles.map(deserializeFromFile[ApplicationInfo]).flatten val driverFiles = sortedFiles.filter(_.startsWith("driver_")) - val drivers = driverFiles.map(deserializeFromFile[DriverInfo]) + val drivers = driverFiles.map(deserializeFromFile[DriverInfo]).flatten val workerFiles = sortedFiles.filter(_.startsWith("worker_")) - val workers = workerFiles.map(deserializeFromFile[WorkerInfo]) + val workers = workerFiles.map(deserializeFromFile[WorkerInfo]).flatten (apps, drivers, workers) } private def serializeIntoFile(path: String, value: AnyRef) { val serializer = serialization.findSerializerFor(value) val serialized = serializer.toBinary(value) - zk.create(path, serialized, CreateMode.PERSISTENT) + zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized) } - def deserializeFromFile[T](filename: String)(implicit m: 
Manifest[T]): T = { - val fileData = zk.getData(WORKING_DIR + "/" + filename) + def deserializeFromFile[T](filename: String)(implicit m: Manifest[T]): Option[T] = { + val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename) val clazz = m.runtimeClass.asInstanceOf[Class[T]] val serializer = serialization.serializerFor(clazz) - serializer.fromBinary(fileData).asInstanceOf[T] + try { + Some(serializer.fromBinary(fileData).asInstanceOf[T]) + } catch { + case e: Exception => { + logWarning("Exception while reading persisted file, deleting", e) + zk.delete().forPath(WORKING_DIR + "/" + filename) + None + } + } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 5cc4adbe448b7..90cad3c37fda6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -23,7 +23,8 @@ import scala.concurrent.Await import scala.xml.Node import akka.pattern.ask -import net.liftweb.json.JsonAST.JValue +import javax.servlet.http.HttpServletRequest +import org.json4s.JValue import org.apache.spark.deploy.JsonProtocol import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala index 01c8f9065e50a..3233cd97f7bd0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala @@ -23,7 +23,8 @@ import scala.concurrent.Await import scala.xml.Node import akka.pattern.ask -import net.liftweb.json.JsonAST.JValue +import javax.servlet.http.HttpServletRequest +import org.json4s.JValue import org.apache.spark.deploy.{DeployWebUI, JsonProtocol} import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} @@ -85,6 +86,7 @@ private[spark] class IndexPage(parent: MasterWebUI) {
                 <li><strong>Drivers:</strong> {state.activeDrivers.size} Running, {state.completedDrivers.size} Completed</li>
+                <li><strong>Status:</strong> {state.status}</li>
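The MasterWebUI and WorkerWebUI changes below swap raw (path, Handler) pairs for ServletContextHandlers built via JettyUtils, so an authentication filter can be attached to each servlet. The standalone sketch below shows the same general shape in plain Jetty; it does not use Spark's JettyUtils, and the port and paths are illustrative:

    import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}

    import org.eclipse.jetty.server.Server
    import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder}

    object ServletHandlerSketch {
      // Wrap a render function in a servlet, roughly what createServlet does above.
      def servletFor(render: HttpServletRequest => String): HttpServlet = new HttpServlet {
        override def doGet(req: HttpServletRequest, resp: HttpServletResponse): Unit = {
          resp.setContentType("text/html;charset=utf-8")
          resp.getWriter.print(render(req))
        }
      }

      def main(args: Array[String]): Unit = {
        val server = new Server(8080)
        val context = new ServletContextHandler(ServletContextHandler.SESSIONS)
        context.setContextPath("/")
        // A servlet filter (e.g. for authentication) could be added to this context here.
        context.addServlet(new ServletHolder(servletFor(_ => "<h1>status</h1>")), "/json")
        server.setHandler(context)
        server.start()
        server.join()
      }
    }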
  • diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 5ab13e7aa6b1f..4ad1f95be31c9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -18,8 +18,8 @@ package org.apache.spark.deploy.master.ui import javax.servlet.http.HttpServletRequest - -import org.eclipse.jetty.server.{Handler, Server} +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.Logging import org.apache.spark.deploy.master.Master @@ -46,7 +46,7 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging { def start() { try { - val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers) + val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers, master.conf) server = Some(srv) boundPort = Some(bPort) logInfo("Started Master web UI at http://%s:%d".format(host, boundPort.get)) @@ -60,12 +60,17 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging { val metricsHandlers = master.masterMetricsSystem.getServletHandlers ++ master.applicationMetricsSystem.getServletHandlers - val handlers = metricsHandlers ++ Array[(String, Handler)]( - ("/static", createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR)), - ("/app/json", (request: HttpServletRequest) => applicationPage.renderJson(request)), - ("/app", (request: HttpServletRequest) => applicationPage.render(request)), - ("/json", (request: HttpServletRequest) => indexPage.renderJson(request)), - ("*", (request: HttpServletRequest) => indexPage.render(request)) + val handlers = metricsHandlers ++ Seq[ServletContextHandler]( + createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR + "/static", "/static"), + createServletHandler("/app/json", + createServlet((request: HttpServletRequest) => applicationPage.renderJson(request), + master.securityMgr)), + createServletHandler("/app", createServlet((request: HttpServletRequest) => applicationPage + .render(request), master.securityMgr)), + createServletHandler("/json", createServlet((request: HttpServletRequest) => indexPage + .renderJson(request), master.securityMgr)), + createServletHandler("*", createServlet((request: HttpServletRequest) => indexPage.render + (request), master.securityMgr)) ) def stop() { @@ -74,5 +79,5 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging { } private[spark] object MasterWebUI { - val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static" + val STATIC_RESOURCE_DIR = "org/apache/spark/ui" } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index a26e47950a0ec..be15138f62406 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.worker import akka.actor._ -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.util.{AkkaUtils, Utils} /** @@ -29,8 +29,9 @@ object DriverWrapper { def main(args: Array[String]) { args.toList match { case workerUrl :: mainClass :: extraArgs => + val conf = new SparkConf() val (actorSystem, _) = AkkaUtils.createActorSystem("Driver", - Utils.localHostName(), 0, false, new SparkConf()) + 
Utils.localHostName(), 0, false, conf, new SecurityManager(conf)) actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher") // Delegate to supplied main class diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 7b0b7861b76e1..afaabedffefea 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -27,7 +27,7 @@ import scala.concurrent.duration._ import akka.actor._ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} @@ -48,7 +48,8 @@ private[spark] class Worker( actorSystemName: String, actorName: String, workDirPath: String = null, - val conf: SparkConf) + val conf: SparkConf, + val securityMgr: SecurityManager) extends Actor with Logging { import context.dispatcher @@ -91,7 +92,7 @@ private[spark] class Worker( var coresUsed = 0 var memoryUsed = 0 - val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf) + val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) val workerSource = new WorkerSource(this) def coresFree: Int = cores - coresUsed @@ -347,10 +348,11 @@ private[spark] object Worker { val conf = new SparkConf val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") val actorName = "Worker" + val securityMgr = new SecurityManager(conf) val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, - conf = conf) + conf = conf, securityManager = securityMgr) actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, - masterUrls, systemName, actorName, workDir, conf), name = actorName) + masterUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName) (actorSystem, boundPort) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala index 3089acffb8d98..85200ab0e102d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala @@ -22,7 +22,7 @@ import scala.xml.Node import akka.pattern.ask import javax.servlet.http.HttpServletRequest -import net.liftweb.json.JsonAST.JValue +import org.json4s.JValue import org.apache.spark.deploy.JsonProtocol import org.apache.spark.deploy.DeployMessages.{RequestWorkerState, WorkerStateResponse} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index bdf126f93abc8..4e33b330ad4e7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -19,8 +19,8 @@ package org.apache.spark.deploy.worker.ui import java.io.File import javax.servlet.http.HttpServletRequest - -import org.eclipse.jetty.server.{Handler, Server} +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.Logging import org.apache.spark.deploy.worker.Worker @@ -33,7 
+33,7 @@ import org.apache.spark.util.{AkkaUtils, Utils} */ private[spark] class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[Int] = None) - extends Logging { + extends Logging { val timeout = AkkaUtils.askTimeout(worker.conf) val host = Utils.localHostName() val port = requestedPort.getOrElse( @@ -46,17 +46,21 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I val metricsHandlers = worker.metricsSystem.getServletHandlers - val handlers = metricsHandlers ++ Array[(String, Handler)]( - ("/static", createStaticHandler(WorkerWebUI.STATIC_RESOURCE_DIR)), - ("/log", (request: HttpServletRequest) => log(request)), - ("/logPage", (request: HttpServletRequest) => logPage(request)), - ("/json", (request: HttpServletRequest) => indexPage.renderJson(request)), - ("*", (request: HttpServletRequest) => indexPage.render(request)) + val handlers = metricsHandlers ++ Seq[ServletContextHandler]( + createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE + "/static", "/static"), + createServletHandler("/log", createServlet((request: HttpServletRequest) => log(request), + worker.securityMgr)), + createServletHandler("/logPage", createServlet((request: HttpServletRequest) => logPage + (request), worker.securityMgr)), + createServletHandler("/json", createServlet((request: HttpServletRequest) => indexPage + .renderJson(request), worker.securityMgr)), + createServletHandler("*", createServlet((request: HttpServletRequest) => indexPage.render + (request), worker.securityMgr)) ) def start() { try { - val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers) + val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers, worker.conf) server = Some(srv) boundPort = Some(bPort) logInfo("Started Worker web UI at http://%s:%d".format(host, bPort)) @@ -198,6 +202,6 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I } private[spark] object WorkerWebUI { - val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static" + val STATIC_RESOURCE_BASE = "org/apache/spark/ui" val DEFAULT_PORT="8081" } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 0aae569b17272..3486092a140fb 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import akka.actor._ import akka.remote._ -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ @@ -97,10 +97,11 @@ private[spark] object CoarseGrainedExecutorBackend { // Debug code Utils.checkHost(hostname) + val conf = new SparkConf // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor // before getting started with all our system properties, etc val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0, - indestructible = true, conf = new SparkConf) + indestructible = true, conf = conf, new SecurityManager(conf)) // set it val sparkHostPort = hostname + ":" + boundPort actorSystem.actorOf( diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala 
b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 989d666f15600..2ea2ec29f59f5 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -29,7 +29,7 @@ import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler._ import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} -import org.apache.spark.util.Utils +import org.apache.spark.util.{AkkaUtils, Utils} /** * Spark executor used with Mesos, YARN, and the standalone scheduler. @@ -69,11 +69,6 @@ private[spark] class Executor( conf.set("spark.local.dir", getYarnLocalDirs()) } - // Create our ClassLoader and set it on this thread - private val urlClassLoader = createClassLoader() - private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) - Thread.currentThread.setContextClassLoader(replClassLoader) - if (!isLocal) { // Setup an uncaught exception handler for non-local mode. // Make any thread terminations due to uncaught exceptions kill the entire @@ -117,11 +112,15 @@ private[spark] class Executor( } } + // Create our ClassLoader and set it on this thread + // do this after SparkEnv creation so can access the SecurityManager + private val urlClassLoader = createClassLoader() + private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) + Thread.currentThread.setContextClassLoader(replClassLoader) + // Akka's message frame size. If task result is bigger than this, we use the block manager // to send the result back. - private val akkaFrameSize = { - env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size") - } + private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) // Start worker thread pool val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker") @@ -338,12 +337,12 @@ private[spark] class Executor( // Fetch missing dependencies for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager) currentFiles(name) = timestamp } for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager) currentJars(name) = timestamp // Add it to our class loader val localName = name.split("/").last diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 455339943f42d..760458cb02a9b 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -103,13 +103,6 @@ class ShuffleReadMetrics extends Serializable { */ var fetchWaitTime: Long = _ - /** - * Total time spent fetching remote shuffle blocks. This aggregates the time spent fetching all - * input blocks. Since block fetches are both pipelined and parallelized, this can - * exceed fetchWaitTime and executorRunTime. 
- */ - var remoteFetchTime: Long = _ - /** * Total number of remote bytes read from the shuffle by this task */ diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 966c092124266..c5bda2078fc14 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -24,7 +24,7 @@ import scala.collection.mutable import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry} -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.metrics.sink.{MetricsServlet, Sink} import org.apache.spark.metrics.source.Source @@ -64,7 +64,7 @@ import org.apache.spark.metrics.source.Source * [options] is the specific property of this source or sink. */ private[spark] class MetricsSystem private (val instance: String, - conf: SparkConf) extends Logging { + conf: SparkConf, securityMgr: SecurityManager) extends Logging { val confFile = conf.get("spark.metrics.conf", null) val metricsConfig = new MetricsConfig(Option(confFile)) @@ -131,8 +131,8 @@ private[spark] class MetricsSystem private (val instance: String, val classPath = kv._2.getProperty("class") try { val sink = Class.forName(classPath) - .getConstructor(classOf[Properties], classOf[MetricRegistry]) - .newInstance(kv._2, registry) + .getConstructor(classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager]) + .newInstance(kv._2, registry, securityMgr) if (kv._1 == "servlet") { metricsServlet = Some(sink.asInstanceOf[MetricsServlet]) } else { @@ -160,6 +160,7 @@ private[spark] object MetricsSystem { } } - def createMetricsSystem(instance: String, conf: SparkConf): MetricsSystem = - new MetricsSystem(instance, conf) + def createMetricsSystem(instance: String, conf: SparkConf, + securityMgr: SecurityManager): MetricsSystem = + new MetricsSystem(instance, conf, securityMgr) } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala index 98fa1dbd7c6ab..4d2ffc54d8983 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala @@ -22,9 +22,11 @@ import java.util.concurrent.TimeUnit import com.codahale.metrics.{ConsoleReporter, MetricRegistry} +import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem -class ConsoleSink(val property: Properties, val registry: MetricRegistry) extends Sink { +class ConsoleSink(val property: Properties, val registry: MetricRegistry, + securityMgr: SecurityManager) extends Sink { val CONSOLE_DEFAULT_PERIOD = 10 val CONSOLE_DEFAULT_UNIT = "SECONDS" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala index 40f64768e6885..319f40815d65f 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala @@ -23,9 +23,11 @@ import java.util.concurrent.TimeUnit import com.codahale.metrics.{CsvReporter, MetricRegistry} +import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem -class CsvSink(val property: Properties, val registry: MetricRegistry) extends Sink { +class CsvSink(val property: Properties, val registry: MetricRegistry, + securityMgr: SecurityManager) 
extends Sink { val CSV_KEY_PERIOD = "period" val CSV_KEY_UNIT = "unit" val CSV_KEY_DIR = "directory" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index e09be001421fc..0ffdf3846dc4a 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -24,9 +24,11 @@ import java.util.concurrent.TimeUnit import com.codahale.metrics.MetricRegistry import com.codahale.metrics.graphite.{Graphite, GraphiteReporter} +import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem -class GraphiteSink(val property: Properties, val registry: MetricRegistry) extends Sink { +class GraphiteSink(val property: Properties, val registry: MetricRegistry, + securityMgr: SecurityManager) extends Sink { val GRAPHITE_DEFAULT_PERIOD = 10 val GRAPHITE_DEFAULT_UNIT = "SECONDS" val GRAPHITE_DEFAULT_PREFIX = "" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala index b5cf210af2119..3b5edd5c376f0 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala @@ -20,8 +20,11 @@ package org.apache.spark.metrics.sink import java.util.Properties import com.codahale.metrics.{JmxReporter, MetricRegistry} +import org.apache.spark.SecurityManager + +class JmxSink(val property: Properties, val registry: MetricRegistry, + securityMgr: SecurityManager) extends Sink { -class JmxSink(val property: Properties, val registry: MetricRegistry) extends Sink { val reporter: JmxReporter = JmxReporter.forRegistry(registry).build() override def start() { diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 3cdfe26d40f66..3110eccdee4fc 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -19,16 +19,19 @@ package org.apache.spark.metrics.sink import java.util.Properties import java.util.concurrent.TimeUnit + import javax.servlet.http.HttpServletRequest import com.codahale.metrics.MetricRegistry import com.codahale.metrics.json.MetricsModule import com.fasterxml.jackson.databind.ObjectMapper -import org.eclipse.jetty.server.Handler +import org.eclipse.jetty.servlet.ServletContextHandler +import org.apache.spark.SecurityManager import org.apache.spark.ui.JettyUtils -class MetricsServlet(val property: Properties, val registry: MetricRegistry) extends Sink { +class MetricsServlet(val property: Properties, val registry: MetricRegistry, + securityMgr: SecurityManager) extends Sink { val SERVLET_KEY_PATH = "path" val SERVLET_KEY_SAMPLE = "sample" @@ -42,8 +45,11 @@ class MetricsServlet(val property: Properties, val registry: MetricRegistry) ext val mapper = new ObjectMapper().registerModule( new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample)) - def getHandlers = Array[(String, Handler)]( - (servletPath, JettyUtils.createHandler(request => getMetricsSnapshot(request), "text/json")) + def getHandlers = Array[ServletContextHandler]( + JettyUtils.createServletHandler(servletPath, + JettyUtils.createServlet( + new JettyUtils.ServletParams(request => getMetricsSnapshot(request), "text/json"), + securityMgr) ) ) def 
getMetricsSnapshot(request: HttpServletRequest): String = { diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala index d3c09b16063d6..04df2f3b0d696 100644 --- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala @@ -45,9 +45,10 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: throw new Exception("Max chunk size is " + maxChunkSize) } + val security = if (isSecurityNeg) 1 else 0 if (size == 0 && !gotChunkForSendingOnce) { val newChunk = new MessageChunk( - new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null) + new MessageChunkHeader(typ, id, 0, 0, ackId, security, senderAddress), null) gotChunkForSendingOnce = true return Some(newChunk) } @@ -65,7 +66,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: } buffer.position(buffer.position + newBuffer.remaining) val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) + typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer) gotChunkForSendingOnce = true return Some(newChunk) } @@ -79,6 +80,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: throw new Exception("Attempting to get chunk from message with multiple data buffers") } val buffer = buffers(0) + val security = if (isSecurityNeg) 1 else 0 if (buffer.remaining > 0) { if (buffer.remaining < chunkSize) { throw new Exception("Not enough space in data buffer for receiving chunk") @@ -86,7 +88,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] buffer.position(buffer.position + newBuffer.remaining) val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) + typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer) return Some(newChunk) } None diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala index f2e3c1a14ecc6..8fd9c2b87d256 100644 --- a/core/src/main/scala/org/apache/spark/network/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/Connection.scala @@ -17,6 +17,11 @@ package org.apache.spark.network +import org.apache.spark._ +import org.apache.spark.SparkSaslServer + +import scala.collection.mutable.{HashMap, Queue, ArrayBuffer} + import java.net._ import java.nio._ import java.nio.channels._ @@ -27,13 +32,16 @@ import org.apache.spark._ private[spark] abstract class Connection(val channel: SocketChannel, val selector: Selector, - val socketRemoteConnectionManagerId: ConnectionManagerId) + val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId) extends Logging { - def this(channel_ : SocketChannel, selector_ : Selector) = { + var sparkSaslServer: SparkSaslServer = null + var sparkSaslClient: SparkSaslClient = null + + def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = { this(channel_, selector_, ConnectionManagerId.fromSocketAddress( - channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress])) + channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]), id_) } channel.configureBlocking(false) @@ -49,6 +57,16 @@ abstract 
class Connection(val channel: SocketChannel, val selector: Selector, val remoteAddress = getRemoteAddress() + /** + * Used to synchronize client requests: client's work-related requests must + * wait until SASL authentication completes. + */ + private val authenticated = new Object() + + def getAuthenticated(): Object = authenticated + + def isSaslComplete(): Boolean + def resetForceReregister(): Boolean // Read channels typically do not register for write and write does not for read @@ -69,6 +87,16 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, // Will be true for ReceivingConnection, false for SendingConnection. def changeInterestForRead(): Boolean + private def disposeSasl() { + if (sparkSaslServer != null) { + sparkSaslServer.dispose(); + } + + if (sparkSaslClient != null) { + sparkSaslClient.dispose() + } + } + // On receiving a write event, should we change the interest for this channel or not ? // Will be false for ReceivingConnection, true for SendingConnection. // Actually, for now, should not get triggered for ReceivingConnection @@ -101,6 +129,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, k.cancel() } channel.close() + disposeSasl() callOnCloseCallback() } @@ -168,10 +197,14 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector, - remoteId_ : ConnectionManagerId) - extends Connection(SocketChannel.open, selector_, remoteId_) { + remoteId_ : ConnectionManagerId, id_ : ConnectionId) + extends Connection(SocketChannel.open, selector_, remoteId_, id_) { - private class Outbox(fair: Int = 0) { + def isSaslComplete(): Boolean = { + if (sparkSaslClient != null) sparkSaslClient.isComplete() else false + } + + private class Outbox { val messages = new Queue[Message]() val defaultChunkSize = 65536 //32768 //16384 var nextMessageToBeUsed = 0 @@ -186,38 +219,6 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } def getChunk(): Option[MessageChunk] = { - fair match { - case 0 => getChunkFIFO() - case 1 => getChunkRR() - case _ => throw new Exception("Unexpected fairness policy in outbox") - } - } - - private def getChunkFIFO(): Option[MessageChunk] = { - /*logInfo("Using FIFO")*/ - messages.synchronized { - while (!messages.isEmpty) { - val message = messages(0) - val chunk = message.getChunkForSending(defaultChunkSize) - if (chunk.isDefined) { - messages += message // this is probably incorrect, it wont work as fifo - if (!message.started) { - logDebug("Starting to send [" + message + "]") - message.started = true - message.startTime = System.currentTimeMillis - } - return chunk - } else { - message.finishTime = System.currentTimeMillis - logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + - "] in " + message.timeTaken ) - } - } - } - None - } - - private def getChunkRR(): Option[MessageChunk] = { messages.synchronized { while (!messages.isEmpty) { /*nextMessageToBeUsed = nextMessageToBeUsed % messages.size */ @@ -249,7 +250,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, // outbox is used as a lock - ensure that it is always used as a leaf (since methods which // lock it are invoked in context of other locks) - private val outbox = new Outbox(1) + private val outbox = new Outbox() /* This is orthogonal to whether we have pending bytes to write or not - and satisfies a slightly different purpose. 
This flag is to see if we need to force reregister for write even when we @@ -258,6 +259,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, data as detailed in https://github.com/mesos/spark/pull/791 */ private var needForceReregister = false + val currentBuffers = new ArrayBuffer[ByteBuffer]() /*channel.socket.setSendBufferSize(256 * 1024)*/ @@ -348,6 +350,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, // If we have 'seen' pending messages, then reset flag - since we handle that as // normal registering of event (below) if (needForceReregister && buffers.exists(_.remaining() > 0)) resetForceReregister() + currentBuffers ++= buffers } case None => { @@ -416,8 +419,15 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, // Must be created within selector loop - else deadlock -private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector) - extends Connection(channel_, selector_) { +private[spark] class ReceivingConnection( + channel_ : SocketChannel, + selector_ : Selector, + id_ : ConnectionId) + extends Connection(channel_, selector_, id_) { + + def isSaslComplete(): Boolean = { + if (sparkSaslServer != null) sparkSaslServer.isComplete() else false + } class Inbox() { val messages = new HashMap[Int, BufferMessage]() @@ -428,6 +438,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S val newMessage = Message.create(header).asInstanceOf[BufferMessage] newMessage.started = true newMessage.startTime = System.currentTimeMillis + newMessage.isSecurityNeg = header.securityNeg == 1 logDebug( "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]") messages += ((newMessage.id, newMessage)) @@ -473,7 +484,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S val inbox = new Inbox() val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE) - var onReceiveCallback: (Connection , Message) => Unit = null + var onReceiveCallback: (Connection, Message) => Unit = null var currentChunk: MessageChunk = null channel.register(selector, SelectionKey.OP_READ) @@ -548,7 +559,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S } } } catch { - case e: Exception => { + case e: Exception => { logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) callOnExceptionCallback(e) close() diff --git a/core/src/main/scala/org/apache/spark/api/java/function/VoidFunction.scala b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala similarity index 55% rename from core/src/main/scala/org/apache/spark/api/java/function/VoidFunction.scala rename to core/src/main/scala/org/apache/spark/network/ConnectionId.scala index ea94313a4ab59..ffaab677d411a 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/VoidFunction.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala @@ -15,19 +15,20 @@ * limitations under the License. */ -package org.apache.spark.api.java.function +package org.apache.spark.network -/** - * A function with no return value. - */ -// This allows Java users to write void methods without having to return Unit. 
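The renamed file below (VoidFunction.scala becoming ConnectionId.scala) introduces a ConnectionId whose string form is host_port_uniqId, plus a parser that splits it back into exactly three fields. A minimal stand-alone round-trip sketch of that format, using an illustrative Id case class rather than the real Spark types:

object ConnectionIdFormatSketch {
  // Illustrative stand-in for ConnectionId/ConnectionManagerId; only the string format matters here.
  final case class Id(host: String, port: Int, uniqId: Int) {
    override def toString: String = host + "_" + port + "_" + uniqId
  }

  def parse(s: String): Id = {
    val parts = s.split("_").map(_.trim)
    // Mirrors createConnectionIdFromString: exactly three "_"-separated fields are expected.
    require(parts.length == 3, "Error converting ConnectionId string: " + s)
    Id(parts(0), parts(1).toInt, parts(2).toInt)
  }

  def main(args: Array[String]): Unit = {
    val id = Id("worker-1", 45123, 7)
    assert(parse(id.toString) == id)  // round-trips as long as the host itself contains no "_"
    println(id)
  }
}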
-abstract class VoidFunction[T] extends Serializable { - @throws(classOf[Exception]) - def call(t: T) : Unit +private[spark] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) { + override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId } -// VoidFunction cannot extend AbstractFunction1 (because that would force users to explicitly -// return Unit), so it is implicitly converted to a Function1[T, Unit]: -object VoidFunction { - implicit def toFunction[T](f: VoidFunction[T]) : Function1[T, Unit] = ((x : T) => f.call(x)) +private[spark] object ConnectionId { + + def createConnectionIdFromString(connectionIdString: String): ConnectionId = { + val res = connectionIdString.split("_").map(_.trim()) + if (res.size != 3) { + throw new Exception("Error converting ConnectionId string: " + connectionIdString + + " to a ConnectionId Object") + } + new ConnectionId(new ConnectionManagerId(res(0), res(1).toInt), res(2).toInt) + } } diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 3dd82bee0b5fd..a75130cba2a2e 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -21,6 +21,9 @@ import java.net._ import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ +import java.net._ +import java.util.concurrent.atomic.AtomicInteger + import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor} import scala.collection.mutable.ArrayBuffer @@ -28,13 +31,15 @@ import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet import scala.collection.mutable.SynchronizedMap import scala.collection.mutable.SynchronizedQueue + import scala.concurrent.{Await, ExecutionContext, Future, Promise} import scala.concurrent.duration._ import org.apache.spark._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{SystemClock, Utils} -private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Logging { +private[spark] class ConnectionManager(port: Int, conf: SparkConf, + securityManager: SecurityManager) extends Logging { class MessageStatus( val message: Message, @@ -50,6 +55,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi private val selector = SelectorProvider.provider.openSelector() + // default to 30 second timeout waiting for authentication + private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30) + private val handleMessageExecutor = new ThreadPoolExecutor( conf.getInt("spark.core.connection.handler.threads.min", 20), conf.getInt("spark.core.connection.handler.threads.max", 60), @@ -71,6 +79,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi new LinkedBlockingDeque[Runnable]()) private val serverChannel = ServerSocketChannel.open() + // used to track the SendingConnections waiting to do SASL negotiation + private val connectionsAwaitingSasl = new HashMap[ConnectionId, SendingConnection] + with SynchronizedMap[ConnectionId, SendingConnection] private val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] @@ -84,6 +95,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi private var onReceiveCallback: 
(BufferMessage, ConnectionManagerId) => Option[Message]= null + private val authEnabled = securityManager.isAuthenticationEnabled() + serverChannel.configureBlocking(false) serverChannel.socket.setReuseAddress(true) serverChannel.socket.setReceiveBufferSize(256 * 1024) @@ -94,6 +107,10 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort) logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id) + // used in combination with the ConnectionManagerId to create unique Connection ids + // to be able to track asynchronous messages + private val idCount: AtomicInteger = new AtomicInteger(1) + private val selectorThread = new Thread("connection-manager-thread") { override def run() = ConnectionManager.this.run() } @@ -125,7 +142,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi } finally { writeRunnableStarted.synchronized { writeRunnableStarted -= key - val needReregister = register || conn.resetForceReregister() + val needReregister = register || conn.resetForceReregister() if (needReregister && conn.changeInterestForWrite()) { conn.registerInterest() } @@ -372,7 +389,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi // accept them all in a tight loop. non blocking accept with no processing, should be fine while (newChannel != null) { try { - val newConnection = new ReceivingConnection(newChannel, selector) + val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) + val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId) newConnection.onReceive(receiveMessage) addListeners(newConnection) addConnection(newConnection) @@ -406,6 +424,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi logInfo("Removing SendingConnection to " + sendingConnectionManagerId) connectionsById -= sendingConnectionManagerId + connectionsAwaitingSasl -= connection.connectionId messageStatuses.synchronized { messageStatuses @@ -481,7 +500,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi val creationTime = System.currentTimeMillis def run() { logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") - handleMessage(connectionManagerId, message) + handleMessage(connectionManagerId, message, connection) logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") } } @@ -489,10 +508,133 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi /*handleMessage(connection, message)*/ } - private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) { + private def handleClientAuthentication( + waitingConn: SendingConnection, + securityMsg: SecurityMessage, + connectionId : ConnectionId) { + if (waitingConn.isSaslComplete()) { + logDebug("Client sasl completed for id: " + waitingConn.connectionId) + connectionsAwaitingSasl -= waitingConn.connectionId + waitingConn.getAuthenticated().synchronized { + waitingConn.getAuthenticated().notifyAll(); + } + return + } else { + var replyToken : Array[Byte] = null + try { + replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken); + if (waitingConn.isSaslComplete()) { + logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId) + connectionsAwaitingSasl -= waitingConn.connectionId + 
waitingConn.getAuthenticated().synchronized { + waitingConn.getAuthenticated().notifyAll() + } + return + } + var securityMsgResp = SecurityMessage.fromResponse(replyToken, + securityMsg.getConnectionId.toString()) + var message = securityMsgResp.toBufferMessage + if (message == null) throw new Exception("Error creating security message") + sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) + } catch { + case e: Exception => { + logError("Error handling sasl client authentication", e) + waitingConn.close() + throw new Exception("Error evaluating sasl response: " + e) + } + } + } + } + + private def handleServerAuthentication( + connection: Connection, + securityMsg: SecurityMessage, + connectionId: ConnectionId) { + if (!connection.isSaslComplete()) { + logDebug("saslContext not established") + var replyToken : Array[Byte] = null + try { + connection.synchronized { + if (connection.sparkSaslServer == null) { + logDebug("Creating sasl Server") + connection.sparkSaslServer = new SparkSaslServer(securityManager) + } + } + replyToken = connection.sparkSaslServer.response(securityMsg.getToken) + if (connection.isSaslComplete()) { + logDebug("Server sasl completed: " + connection.connectionId) + } else { + logDebug("Server sasl not completed: " + connection.connectionId) + } + if (replyToken != null) { + var securityMsgResp = SecurityMessage.fromResponse(replyToken, + securityMsg.getConnectionId) + var message = securityMsgResp.toBufferMessage + if (message == null) throw new Exception("Error creating security Message") + sendSecurityMessage(connection.getRemoteConnectionManagerId(), message) + } + } catch { + case e: Exception => { + logError("Error in server auth negotiation: " + e) + // It would probably be better to send an error message telling other side auth failed + // but for now just close + connection.close() + } + } + } else { + logDebug("connection already established for this connection id: " + connection.connectionId) + } + } + + + private def handleAuthentication(conn: Connection, bufferMessage: BufferMessage): Boolean = { + if (bufferMessage.isSecurityNeg) { + logDebug("This is security neg message") + + // parse as SecurityMessage + val securityMsg = SecurityMessage.fromBufferMessage(bufferMessage) + val connectionId = ConnectionId.createConnectionIdFromString(securityMsg.getConnectionId) + + connectionsAwaitingSasl.get(connectionId) match { + case Some(waitingConn) => { + // Client - this must be in response to us doing Send + logDebug("Client handleAuth for id: " + waitingConn.connectionId) + handleClientAuthentication(waitingConn, securityMsg, connectionId) + } + case None => { + // Server - someone sent us something and we haven't authenticated yet + logDebug("Server handleAuth for id: " + connectionId) + handleServerAuthentication(conn, securityMsg, connectionId) + } + } + return true + } else { + if (!conn.isSaslComplete()) { + // We could handle this better and tell the client we need to do authentication + // negotiation, but for now just ignore them. 
+ logError("message sent that is not security negotiation message on connection " + + "not authenticated yet, ignoring it!!") + return true + } + } + return false + } + + private def handleMessage( + connectionManagerId: ConnectionManagerId, + message: Message, + connection: Connection) { logDebug("Handling [" + message + "] from [" + connectionManagerId + "]") message match { case bufferMessage: BufferMessage => { + if (authEnabled) { + val res = handleAuthentication(connection, bufferMessage) + if (res == true) { + // message was security negotiation so skip the rest + logDebug("After handleAuth result was true, returning") + return + } + } if (bufferMessage.hasAckId) { val sentMessageStatus = messageStatuses.synchronized { messageStatuses.get(bufferMessage.ackId) match { @@ -541,20 +683,124 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi } } + private def checkSendAuthFirst(connManagerId: ConnectionManagerId, conn: SendingConnection) { + // see if we need to do sasl before writing + // this should only be the first negotiation as the Client!!! + if (!conn.isSaslComplete()) { + conn.synchronized { + if (conn.sparkSaslClient == null) { + conn.sparkSaslClient = new SparkSaslClient(securityManager) + var firstResponse: Array[Byte] = null + try { + firstResponse = conn.sparkSaslClient.firstToken() + var securityMsg = SecurityMessage.fromResponse(firstResponse, + conn.connectionId.toString()) + var message = securityMsg.toBufferMessage + if (message == null) throw new Exception("Error creating security message") + connectionsAwaitingSasl += ((conn.connectionId, conn)) + sendSecurityMessage(connManagerId, message) + logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId) + } catch { + case e: Exception => { + logError("Error getting first response from the SaslClient.", e) + conn.close() + throw new Exception("Error getting first response from the SaslClient") + } + } + } + } + } else { + logDebug("Sasl already established ") + } + } + + // allow us to add messages to the inbox for doing sasl negotiating + private def sendSecurityMessage(connManagerId: ConnectionManagerId, message: Message) { + def startNewConnection(): SendingConnection = { + val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port) + val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) + val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId, + newConnectionId) + logInfo("creating new sending connection for security! " + newConnectionId ) + registerRequests.enqueue(newConnection) + + newConnection + } + // I removed the lookupKey stuff as part of merge ... should I re-add it ? + // We did not find it useful in our test-env ... + // If we do re-add it, we should consistently use it everywhere I guess ? 
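The handleAuthentication logic above reduces to a dispatch on the ConnectionId carried by a security message: if it matches an entry in connectionsAwaitingSasl, the message answers a handshake this node initiated as a client; otherwise a remote peer is initiating and this node acts as the server. A stand-alone sketch of that rule, with an illustrative ConnId type and map rather than the Spark classes:

import scala.collection.mutable

object SaslDispatchSketch {
  final case class ConnId(host: String, port: Int, uniqId: Int)

  // Stands in for connectionsAwaitingSasl: client-side handshakes this node started.
  private val awaitingSasl = mutable.Map[ConnId, String]()

  def registerClientHandshake(id: ConnId): Unit = awaitingSasl += (id -> "pending")

  def dispatch(idInMessage: ConnId): String = awaitingSasl.get(idInMessage) match {
    case Some(_) => "client: response to a handshake we initiated"
    case None    => "server: a remote peer is authenticating to us"
  }

  def main(args: Array[String]): Unit = {
    val ours = ConnId("node-0", 7077, 1)
    registerClientHandshake(ours)
    println(dispatch(ours))                       // client path
    println(dispatch(ConnId("node-1", 7077, 3)))  // server path
  }
}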
+ message.senderAddress = id.toSocketAddress() + logTrace("Sending Security [" + message + "] to [" + connManagerId + "]") + val connection = connectionsById.getOrElseUpdate(connManagerId, startNewConnection()) + + //send security message until going connection has been authenticated + connection.send(message) + + wakeupSelector() + } + private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { def startNewConnection(): SendingConnection = { val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port) - val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId) + val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) + val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId, + newConnectionId) + logTrace("creating new sending connection: " + newConnectionId) registerRequests.enqueue(newConnection) newConnection } - // I removed the lookupKey stuff as part of merge ... should I re-add it ? We did not find it - // useful in our test-env ... If we do re-add it, we should consistently use it everywhere I - // guess ? val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection()) + if (authEnabled) { + checkSendAuthFirst(connectionManagerId, connection) + } message.senderAddress = id.toSocketAddress() + logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " + + "connectionid: " + connection.connectionId) + + if (authEnabled) { + // if we aren't authenticated yet lets block the senders until authentication completes + try { + connection.getAuthenticated().synchronized { + val clock = SystemClock + val startTime = clock.getTime() + + while (!connection.isSaslComplete()) { + logDebug("getAuthenticated wait connectionid: " + connection.connectionId) + // have timeout in case remote side never responds + connection.getAuthenticated().wait(500) + if (((clock.getTime() - startTime) >= (authTimeout * 1000)) + && (!connection.isSaslComplete())) { + // took to long to authenticate the connection, something probably went wrong + throw new Exception("Took to long for authentication to " + connectionManagerId + + ", waited " + authTimeout + "seconds, failing.") + } + } + } + } catch { + case e: Exception => logError("Exception while waiting for authentication.", e) + + // need to tell sender it failed + messageStatuses.synchronized { + val s = messageStatuses.get(message.id) + s match { + case Some(msgStatus) => { + messageStatuses -= message.id + logInfo("Notifying " + msgStatus.connectionManagerId) + msgStatus.synchronized { + msgStatus.attempted = true + msgStatus.acked = false + msgStatus.markDone() + } + } + case None => { + logError("no messageStatus for failed message id: " + message.id) + } + } + } + } + } logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") connection.send(message) @@ -606,7 +852,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi private[spark] object ConnectionManager { def main(args: Array[String]) { - val manager = new ConnectionManager(9999, new SparkConf) + val conf = new SparkConf + val manager = new ConnectionManager(9999, conf, new SecurityManager(conf)) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { println("Received [" + msg + "] from [" + id + "]") None diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala index 
20fe67661844f..7caccfdbb44f9 100644 --- a/core/src/main/scala/org/apache/spark/network/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/Message.scala @@ -27,6 +27,7 @@ private[spark] abstract class Message(val typ: Long, val id: Int) { var started = false var startTime = -1L var finishTime = -1L + var isSecurityNeg = false def size: Int diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala index 9bcbc6141a502..ead663ede7a1c 100644 --- a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala @@ -27,6 +27,7 @@ private[spark] class MessageChunkHeader( val totalSize: Int, val chunkSize: Int, val other: Int, + val securityNeg: Int, val address: InetSocketAddress) { lazy val buffer = { // No need to change this, at 'use' time, we do a reverse lookup of the hostname. @@ -40,6 +41,7 @@ private[spark] class MessageChunkHeader( putInt(totalSize). putInt(chunkSize). putInt(other). + putInt(securityNeg). putInt(ip.size). put(ip). putInt(port). @@ -48,12 +50,13 @@ private[spark] class MessageChunkHeader( } override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + - " and sizes " + totalSize + " / " + chunkSize + " bytes" + " and sizes " + totalSize + " / " + chunkSize + " bytes, securityNeg: " + securityNeg + } private[spark] object MessageChunkHeader { - val HEADER_SIZE = 40 + val HEADER_SIZE = 44 def create(buffer: ByteBuffer): MessageChunkHeader = { if (buffer.remaining != HEADER_SIZE) { @@ -64,11 +67,13 @@ private[spark] object MessageChunkHeader { val totalSize = buffer.getInt() val chunkSize = buffer.getInt() val other = buffer.getInt() + val securityNeg = buffer.getInt() val ipSize = buffer.getInt() val ipBytes = new Array[Byte](ipSize) buffer.get(ipBytes) val ip = InetAddress.getByAddress(ipBytes) val port = buffer.getInt() - new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port)) + new MessageChunkHeader(typ, id, totalSize, chunkSize, other, securityNeg, + new InetSocketAddress(ip, port)) } } diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala index 9976255c7e251..3c09a713c6fe0 100644 --- a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala +++ b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala @@ -18,12 +18,12 @@ package org.apache.spark.network import java.nio.ByteBuffer - -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} private[spark] object ReceiverTest { def main(args: Array[String]) { - val manager = new ConnectionManager(9999, new SparkConf) + val conf = new SparkConf + val manager = new ConnectionManager(9999, conf, new SecurityManager(conf)) println("Started connection manager with id = " + manager.id) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { diff --git a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala new file mode 100644 index 0000000000000..0d9f743b3624b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network + +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.StringBuilder + +import org.apache.spark._ +import org.apache.spark.network._ + +/** + * SecurityMessage is class that contains the connectionId and sasl token + * used in SASL negotiation. SecurityMessage has routines for converting + * it to and from a BufferMessage so that it can be sent by the ConnectionManager + * and easily consumed by users when received. + * The api was modeled after BlockMessage. + * + * The connectionId is the connectionId of the client side. Since + * message passing is asynchronous and its possible for the server side (receiving) + * to get multiple different types of messages on the same connection the connectionId + * is used to know which connnection the security message is intended for. + * + * For instance, lets say we are node_0. We need to send data to node_1. The node_0 side + * is acting as a client and connecting to node_1. SASL negotiation has to occur + * between node_0 and node_1 before node_1 trusts node_0 so node_0 sends a security message. + * node_1 receives the message from node_0 but before it can process it and send a response, + * some thread on node_1 decides it needs to send data to node_0 so it connects to node_0 + * and sends a security message of its own to authenticate as a client. Now node_0 gets + * the message and it needs to decide if this message is in response to it being a client + * (from the first send) or if its just node_1 trying to connect to it to send data. This + * is where the connectionId field is used. node_0 can lookup the connectionId to see if + * it is in response to it being a client or if its in response to someone sending other data. + * + * The format of a SecurityMessage as its sent is: + * - Length of the ConnectionId + * - ConnectionId + * - Length of the token + * - Token + */ +private[spark] class SecurityMessage() extends Logging { + + private var connectionId: String = null + private var token: Array[Byte] = null + + def set(byteArr: Array[Byte], newconnectionId: String) { + if (byteArr == null) { + token = new Array[Byte](0) + } else { + token = byteArr + } + connectionId = newconnectionId + } + + /** + * Read the given buffer and set the members of this class. 
+ */ + def set(buffer: ByteBuffer) { + val idLength = buffer.getInt() + val idBuilder = new StringBuilder(idLength) + for (i <- 1 to idLength) { + idBuilder += buffer.getChar() + } + connectionId = idBuilder.toString() + + val tokenLength = buffer.getInt() + token = new Array[Byte](tokenLength) + if (tokenLength > 0) { + buffer.get(token, 0, tokenLength) + } + } + + def set(bufferMsg: BufferMessage) { + val buffer = bufferMsg.buffers.apply(0) + buffer.clear() + set(buffer) + } + + def getConnectionId: String = { + return connectionId + } + + def getToken: Array[Byte] = { + return token + } + + /** + * Create a BufferMessage that can be sent by the ConnectionManager containing + * the security information from this class. + * @return BufferMessage + */ + def toBufferMessage: BufferMessage = { + val startTime = System.currentTimeMillis + val buffers = new ArrayBuffer[ByteBuffer]() + + // 4 bytes for the length of the connectionId + // connectionId is of type char so multiple the length by 2 to get number of bytes + // 4 bytes for the length of token + // token is a byte buffer so just take the length + var buffer = ByteBuffer.allocate(4 + connectionId.length() * 2 + 4 + token.length) + buffer.putInt(connectionId.length()) + connectionId.foreach((x: Char) => buffer.putChar(x)) + buffer.putInt(token.length) + + if (token.length > 0) { + buffer.put(token) + } + buffer.flip() + buffers += buffer + + var message = Message.createBufferMessage(buffers) + logDebug("message total size is : " + message.size) + message.isSecurityNeg = true + return message + } + + override def toString: String = { + "SecurityMessage [connId= " + connectionId + ", Token = " + token + "]" + } +} + +private[spark] object SecurityMessage { + + /** + * Convert the given BufferMessage to a SecurityMessage by parsing the contents + * of the BufferMessage and populating the SecurityMessage fields. + * @param bufferMessage is a BufferMessage that was received + * @return new SecurityMessage + */ + def fromBufferMessage(bufferMessage: BufferMessage): SecurityMessage = { + val newSecurityMessage = new SecurityMessage() + newSecurityMessage.set(bufferMessage) + newSecurityMessage + } + + /** + * Create a SecurityMessage to send from a given saslResponse. 
+ * @param response is the response to a challenge from the SaslClient or Saslserver + * @param connectionId the client connectionId we are negotiation authentication for + * @return a new SecurityMessage + */ + def fromResponse(response : Array[Byte], connectionId : String) : SecurityMessage = { + val newSecurityMessage = new SecurityMessage() + newSecurityMessage.set(response, connectionId) + newSecurityMessage + } +} diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala index 646f8425d9551..aac2c24a46faa 100644 --- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala +++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala @@ -18,8 +18,7 @@ package org.apache.spark.network import java.nio.ByteBuffer - -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} private[spark] object SenderTest { def main(args: Array[String]) { @@ -32,8 +31,8 @@ private[spark] object SenderTest { val targetHost = args(0) val targetPort = args(1).toInt val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort) - - val manager = new ConnectionManager(0, new SparkConf) + val conf = new SparkConf + val manager = new ConnectionManager(0, conf, new SecurityManager(conf)) println("Started connection manager with id = " + manager.id) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 699a10c96c227..8561711931047 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext} import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap} +import org.apache.spark.serializer.Serializer private[spark] sealed trait CoGroupSplitDep extends Serializable @@ -66,10 +67,10 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: private type CoGroupValue = (Any, Int) // Int is dependency number private type CoGroupCombiner = Seq[CoGroup] - private var serializerClass: String = null + private var serializer: Serializer = null - def setSerializer(cls: String): CoGroupedRDD[K] = { - serializerClass = cls + def setSerializer(serializer: Serializer): CoGroupedRDD[K] = { + this.serializer = serializer this } @@ -80,7 +81,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: new OneToOneDependency(rdd) } else { logDebug("Adding shuffle dependency with " + rdd) - new ShuffleDependency[Any, Any](rdd, part, serializerClass) + new ShuffleDependency[Any, Any](rdd, part, serializer) } } } @@ -113,18 +114,17 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: // A list of (rdd iterator, dependency number) pairs val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)] for ((dep, depNum) <- split.deps.zipWithIndex) dep match { - case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { + case NarrowCoGroupSplitDep(rdd, _, itsSplit) => // Read them from the parent val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]] rddIterators += ((it, depNum)) - } - case ShuffleCoGroupSplitDep(shuffleId) 
=> { + + case ShuffleCoGroupSplitDep(shuffleId) => // Read map outputs of shuffle val fetcher = SparkEnv.get.shuffleFetcher - val ser = SparkEnv.get.serializerManager.get(serializerClass, sparkConf) + val ser = Serializer.getSerializer(serializer) val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser) rddIterators += ((it, depNum)) - } } if (!externalSorting) { diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index a374fc4a871b0..100ddb360732a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -18,8 +18,10 @@ package org.apache.spark.rdd import java.io.EOFException +import scala.collection.immutable.Map import org.apache.hadoop.conf.{Configurable, Configuration} +import org.apache.hadoop.mapred.FileSplit import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.mapred.InputSplit import org.apache.hadoop.mapred.JobConf @@ -43,6 +45,23 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp override def hashCode(): Int = 41 * (41 + rddId) + idx override val index: Int = idx + + /** + * Get any environment variables that should be added to the users environment when running pipes + * @return a Map with the environment variables and corresponding values, it could be empty + */ + def getPipeEnvVars(): Map[String, String] = { + val envVars: Map[String, String] = if (inputSplit.value.isInstanceOf[FileSplit]) { + val is: FileSplit = inputSplit.value.asInstanceOf[FileSplit] + // map_input_file is deprecated in favor of mapreduce_map_input_file but set both + // since its not removed yet + Map("map_input_file" -> is.getPath().toString(), + "mapreduce_map_input_file" -> is.getPath().toString()) + } else { + Map() + } + envVars + } } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index d29a1a9881cd4..447deafff53cd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -30,23 +30,21 @@ import scala.reflect.ClassTag import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.hadoop.conf.{Configurable, Configuration} -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} -import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} -import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob} -import org.apache.hadoop.mapreduce.{RecordWriter => NewRecordWriter} +import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, Job => NewAPIHadoopJob, RecordWriter => NewRecordWriter, SparkHadoopMapReduceUtil} import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} // SparkHadoopWriter and SparkHadoopMapReduceUtil are actually source files defined in Spark. 
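The PairRDDFunctions hunk below changes combineByKey, like the shuffle RDDs above, to accept a Serializer instance instead of a serializer class-name String. A usage sketch assuming a local SparkContext; KryoSerializer is used purely as an illustration:

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.serializer.KryoSerializer

object CombineByKeySerializerSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local").setAppName("combineByKey-serializer-sketch")
    val sc = new SparkContext(conf)
    val pairs = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3)))

    // Previously the last argument was a serializer class-name String;
    // after this change it is a Serializer instance.
    val sums = pairs.combineByKey[Int](
      (v: Int) => v,
      (c: Int, v: Int) => c + v,
      (c1: Int, c2: Int) => c1 + c2,
      new HashPartitioner(2),
      mapSideCombine = true,
      serializer = new KryoSerializer(conf))

    println(sums.collect().toSeq)
    sc.stop()
  }
}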
import org.apache.hadoop.mapred.SparkHadoopWriter -import org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.SparkContext._ import org.apache.spark.partial.{BoundedDouble, PartialResult} +import org.apache.spark.serializer.Serializer import org.apache.spark.util.SerializableHyperLogLog /** @@ -76,7 +74,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) mergeCombiners: (C, C) => C, partitioner: Partitioner, mapSideCombine: Boolean = true, - serializerClass: String = null): RDD[(K, C)] = { + serializer: Serializer = null): RDD[(K, C)] = { require(mergeCombiners != null, "mergeCombiners must be defined") // required as of Spark 0.9.0 if (getKeyClass().isArray) { if (mapSideCombine) { @@ -96,13 +94,13 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) aggregator.combineValuesByKey(iter, context) }, preservesPartitioning = true) val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner) - .setSerializer(serializerClass) + .setSerializer(serializer) partitioned.mapPartitionsWithContext((context, iter) => { new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter, context)) }, preservesPartitioning = true) } else { // Don't apply map-side combiner. - val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass) + val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializer) values.mapPartitionsWithContext((context, iter) => { new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context)) }, preservesPartitioning = true) @@ -196,6 +194,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) } /** Alias for reduceByKeyLocally */ + @deprecated("Use reduceByKeyLocally", "1.0.0") def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func) /** Count the number of elements for each key, and return the result to the master as a Map. */ @@ -425,7 +424,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) * Return the key-value pairs in this RDD to the master as a Map. */ def collectAsMap(): Map[K, V] = { - val data = self.toArray() + val data = self.collect() val map = new mutable.HashMap[K, V] map.sizeHint(data.length) data.foreach { case (k, v) => map.put(k, v) } @@ -604,46 +603,9 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) val job = new NewAPIHadoopJob(conf) job.setOutputKeyClass(keyClass) job.setOutputValueClass(valueClass) - val wrappedConf = new SerializableWritable(job.getConfiguration) - NewFileOutputFormat.setOutputPath(job, new Path(path)) - val formatter = new SimpleDateFormat("yyyyMMddHHmm") - val jobtrackerID = formatter.format(new Date()) - val stageId = self.id - def writeShard(context: TaskContext, iter: Iterator[(K,V)]): Int = { - // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it - // around by taking a mod. We expect that no task will be attempted 2 billion times. 
- val attemptNumber = (context.attemptId % Int.MaxValue).toInt - /* "reduce task" */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, - attemptNumber) - val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) - val format = outputFormatClass.newInstance - format match { - case c: Configurable => c.setConf(wrappedConf.value) - case _ => () - } - val committer = format.getOutputCommitter(hadoopContext) - committer.setupTask(hadoopContext) - val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] - while (iter.hasNext) { - val (k, v) = iter.next() - writer.write(k, v) - } - writer.close(hadoopContext) - committer.commitTask(hadoopContext) - return 1 - } - val jobFormat = outputFormatClass.newInstance - /* apparently we need a TaskAttemptID to construct an OutputCommitter; - * however we're only going to use this local OutputCommitter for - * setupJob/commitJob, so we just use a dummy "map" task. - */ - val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0) - val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) - val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) - jobCommitter.setupJob(jobTaskContext) - val count = self.context.runJob(self, writeShard _).sum - jobCommitter.commitJob(jobTaskContext) + job.setOutputFormatClass(outputFormatClass) + job.getConfiguration.set("mapred.output.dir", path) + saveAsNewAPIHadoopDataset(job.getConfiguration) } /** @@ -689,6 +651,59 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) saveAsHadoopDataset(conf) } + /** + * Output the RDD to any Hadoop-supported storage system with new Hadoop API, using a Hadoop + * Configuration object for that storage system. The Conf should set an OutputFormat and any + * output paths required (e.g. a table name to write to) in the same way as it would be + * configured for a Hadoop MapReduce job. + */ + def saveAsNewAPIHadoopDataset(conf: Configuration) { + val job = new NewAPIHadoopJob(conf) + val formatter = new SimpleDateFormat("yyyyMMddHHmm") + val jobtrackerID = formatter.format(new Date()) + val stageId = self.id + val wrappedConf = new SerializableWritable(job.getConfiguration) + val outfmt = job.getOutputFormatClass + val jobFormat = outfmt.newInstance + + if (jobFormat.isInstanceOf[NewFileOutputFormat[_, _]]) { + // FileOutputFormat ignores the filesystem parameter + jobFormat.checkOutputSpecs(job) + } + + def writeShard(context: TaskContext, iter: Iterator[(K,V)]): Int = { + // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it + // around by taking a mod. We expect that no task will be attempted 2 billion times. 
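For the saveAsNewAPIHadoopDataset method being added here, a usage sketch: configure the output entirely through a new-API Job/Configuration and hand Spark the Configuration. The output path and key/value classes below are illustrative; setting mapred.output.dir mirrors what saveAsNewAPIHadoopFile itself now does in this patch.

import org.apache.hadoop.io.{IntWritable, Text}
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._

object NewHadoopDatasetSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("save-dataset-sketch"))
    val data = sc.parallelize(Seq(("a", 1), ("b", 2)))
      .map { case (k, v) => (new Text(k), new IntWritable(v)) }

    val job = new Job()  // new-API Hadoop Job used only as a Configuration holder
    job.setOutputKeyClass(classOf[Text])
    job.setOutputValueClass(classOf[IntWritable])
    job.setOutputFormatClass(classOf[TextOutputFormat[Text, IntWritable]])
    job.getConfiguration.set("mapred.output.dir", "/tmp/spark-save-dataset-sketch")

    data.saveAsNewAPIHadoopDataset(job.getConfiguration)
    sc.stop()
  }
}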
+ val attemptNumber = (context.attemptId % Int.MaxValue).toInt + /* "reduce task" */ + val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, + attemptNumber) + val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) + val format = outfmt.newInstance + format match { + case c: Configurable => c.setConf(wrappedConf.value) + case _ => () + } + val committer = format.getOutputCommitter(hadoopContext) + committer.setupTask(hadoopContext) + val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] + while (iter.hasNext) { + val (k, v) = iter.next() + writer.write(k, v) + } + writer.close(hadoopContext) + committer.commitTask(hadoopContext) + return 1 + } + + val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0) + val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) + val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) + jobCommitter.setupJob(jobTaskContext) + self.context.runJob(self, writeShard _) + jobCommitter.commitJob(jobTaskContext) + } + /** * Output the RDD to any Hadoop-supported storage system, using a Hadoop JobConf object for * that storage system. The JobConf should set an OutputFormat and any output paths required @@ -696,10 +711,10 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) * MapReduce job. */ def saveAsHadoopDataset(conf: JobConf) { - val outputFormatClass = conf.getOutputFormat + val outputFormatInstance = conf.getOutputFormat val keyClass = conf.getOutputKeyClass val valueClass = conf.getOutputValueClass - if (outputFormatClass == null) { + if (outputFormatInstance == null) { throw new SparkException("Output format class not set") } if (keyClass == null) { @@ -712,6 +727,12 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " + valueClass.getSimpleName + ")") + if (outputFormatInstance.isInstanceOf[FileOutputFormat[_, _]]) { + // FileOutputFormat ignores the filesystem parameter + val ignoredFs = FileSystem.get(conf) + conf.getOutputFormat.checkOutputSpecs(ignoredFs, conf) + } + val writer = new SparkHadoopWriter(conf) writer.preSetup() diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index abd4414e81f5c..4250a9d02f764 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -28,6 +28,7 @@ import scala.reflect.ClassTag import org.apache.spark.{Partition, SparkEnv, TaskContext} + /** * An RDD that pipes the contents of each parent partition through an external command * (printing them one per line) and returns the output as a collection of strings. 
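The PipedRDD hunk that follows wires HadoopPartition.getPipeEnvVars into the piped subprocess environment, so external commands see map_input_file / mapreduce_map_input_file just as they would under Hadoop streaming. A usage sketch, assuming a local input file at /tmp/input.txt and a POSIX printenv on the PATH:

import org.apache.spark.{SparkConf, SparkContext}

object PipeEnvVarSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("pipe-env-sketch"))
    val lines = sc.textFile("/tmp/input.txt")  // backed by a HadoopRDD, so partitions are HadoopPartitions

    // Each partition's subprocess can now read the input file name from its environment.
    val inputFiles = lines.pipe("printenv map_input_file").distinct().collect()
    inputFiles.foreach(println)
    sc.stop()
  }
}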
@@ -59,6 +60,13 @@ class PipedRDD[T: ClassTag]( val currentEnvVars = pb.environment() envVars.foreach { case (variable, value) => currentEnvVars.put(variable, value) } + // for compatibility with Hadoop which sets these env variables + // so the user code can access the input filename + if (split.isInstanceOf[HadoopPartition]) { + val hadoopSplit = split.asInstanceOf[HadoopPartition] + currentEnvVars.putAll(hadoopSplit.getPipeEnvVars()) + } + val proc = pb.start() val env = SparkEnv.get diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 50320f40350cd..ddb901246d360 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -126,14 +126,6 @@ abstract class RDD[T: ClassTag]( this } - /** User-defined generator of this RDD*/ - @transient var generator = Utils.getCallSiteInfo.firstUserClass - - /** Reset generator*/ - def setGenerator(_generator: String) = { - generator = _generator - } - /** * Set this RDD's storage level to persist its values across operations after the first time * it is computed. This can only be used to assign a new storage level if the RDD does not @@ -318,6 +310,7 @@ abstract class RDD[T: ClassTag]( * Return a sampled subset of this RDD. */ def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = { + require(fraction >= 0.0, "Invalid fraction value: " + fraction) if (withReplacement) { new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed) } else { @@ -352,6 +345,10 @@ abstract class RDD[T: ClassTag]( throw new IllegalArgumentException("Negative number of elements requested") } + if (initialCount == 0) { + return new Array[T](0) + } + if (initialCount > Integer.MAX_VALUE - 1) { maxSelected = Integer.MAX_VALUE - 1 } else { @@ -370,7 +367,7 @@ abstract class RDD[T: ClassTag]( var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() // If the first sample didn't turn out large enough, keep trying to take samples; - // this shouldn't happen often because we use a big multiplier for thei initial size + // this shouldn't happen often because we use a big multiplier for the initial size while (samples.length < total) { samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() } @@ -543,7 +540,8 @@ abstract class RDD[T: ClassTag]( * additional parameter is produced by constructA, which is called in each * partition with the index of that partition. */ - def mapWith[A: ClassTag, U: ClassTag] + @deprecated("use mapPartitionsWithIndex", "1.0.0") + def mapWith[A, U: ClassTag] (constructA: Int => A, preservesPartitioning: Boolean = false) (f: (T, A) => U): RDD[U] = { mapPartitionsWithIndex((index, iter) => { @@ -557,7 +555,8 @@ abstract class RDD[T: ClassTag]( * additional parameter is produced by constructA, which is called in each * partition with the index of that partition. */ - def flatMapWith[A: ClassTag, U: ClassTag] + @deprecated("use mapPartitionsWithIndex and flatMap", "1.0.0") + def flatMapWith[A, U: ClassTag] (constructA: Int => A, preservesPartitioning: Boolean = false) (f: (T, A) => Seq[U]): RDD[U] = { mapPartitionsWithIndex((index, iter) => { @@ -571,7 +570,8 @@ abstract class RDD[T: ClassTag]( * This additional parameter is produced by constructA, which is called in each * partition with the index of that partition. 
*/ - def foreachWith[A: ClassTag](constructA: Int => A)(f: (T, A) => Unit) { + @deprecated("use mapPartitionsWithIndex and foreach", "1.0.0") + def foreachWith[A](constructA: Int => A)(f: (T, A) => Unit) { mapPartitionsWithIndex { (index, iter) => val a = constructA(index) iter.map(t => {f(t, a); t}) @@ -583,7 +583,8 @@ abstract class RDD[T: ClassTag]( * additional parameter is produced by constructA, which is called in each * partition with the index of that partition. */ - def filterWith[A: ClassTag](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = { + @deprecated("use mapPartitionsWithIndex and filter", "1.0.0") + def filterWith[A](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = { mapPartitionsWithIndex((index, iter) => { val a = constructA(index) iter.filter(t => p(t, a)) @@ -662,6 +663,7 @@ abstract class RDD[T: ClassTag]( /** * Return an array that contains all of the elements in this RDD. */ + @deprecated("use collect", "1.0.0") def toArray(): Array[T] = collect() /** @@ -954,6 +956,18 @@ abstract class RDD[T: ClassTag]( */ def takeOrdered(num: Int)(implicit ord: Ordering[T]): Array[T] = top(num)(ord.reverse) + /** + * Returns the max of this RDD as defined by the implicit Ordering[T]. + * @return the maximum element of the RDD + * */ + def max()(implicit ord: Ordering[T]):T = this.reduce(ord.max) + + /** + * Returns the min of this RDD as defined by the implicit Ordering[T]. + * @return the minimum element of the RDD + * */ + def min()(implicit ord: Ordering[T]):T = this.reduce(ord.min) + /** * Save this RDD as a text file, using string representations of elements. */ @@ -1027,8 +1041,9 @@ abstract class RDD[T: ClassTag]( private var storageLevel: StorageLevel = StorageLevel.NONE - /** Record user function generating this RDD. */ - @transient private[spark] val origin = sc.getCallSite() + /** User code that created this RDD (e.g. `textFile`, `parallelize`). 
*/ + @transient private[spark] val creationSiteInfo = Utils.getCallSiteInfo + private[spark] def getCreationSite = Utils.formatCallSiteInfo(creationSiteInfo) private[spark] def elementClassTag: ClassTag[T] = classTag[T] @@ -1091,10 +1106,7 @@ abstract class RDD[T: ClassTag]( } override def toString: String = "%s%s[%d] at %s".format( - Option(name).map(_ + " ").getOrElse(""), - getClass.getSimpleName, - id, - origin) + Option(name).map(_ + " ").getOrElse(""), getClass.getSimpleName, id, getCreationSite) def toJavaRDD() : JavaRDD[T] = { new JavaRDD(this)(elementClassTag) diff --git a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala index b50307cfa49b7..4ceea557f569c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala @@ -26,13 +26,13 @@ import cern.jet.random.engine.DRand import org.apache.spark.{Partition, TaskContext} -@deprecated("Replaced by PartitionwiseSampledRDDPartition", "1.0") +@deprecated("Replaced by PartitionwiseSampledRDDPartition", "1.0.0") private[spark] class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable { override val index: Int = prev.index } -@deprecated("Replaced by PartitionwiseSampledRDD", "1.0") +@deprecated("Replaced by PartitionwiseSampledRDD", "1.0.0") class SampledRDD[T: ClassTag]( prev: RDD[T], withReplacement: Boolean, diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index 0bbda25a905cd..02660ea6a45c5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -20,6 +20,7 @@ package org.apache.spark.rdd import scala.reflect.ClassTag import org.apache.spark.{Dependency, Partition, Partitioner, ShuffleDependency, SparkEnv, TaskContext} +import org.apache.spark.serializer.Serializer private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { override val index = idx @@ -38,15 +39,15 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag]( part: Partitioner) extends RDD[P](prev.context, Nil) { - private var serializerClass: String = null + private var serializer: Serializer = null - def setSerializer(cls: String): ShuffledRDD[K, V, P] = { - serializerClass = cls + def setSerializer(serializer: Serializer): ShuffledRDD[K, V, P] = { + this.serializer = serializer this } override def getDependencies: Seq[Dependency[_]] = { - List(new ShuffleDependency(prev, part, serializerClass)) + List(new ShuffleDependency(prev, part, serializer)) } override val partitioner = Some(part) @@ -57,8 +58,8 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag]( override def compute(split: Partition, context: TaskContext): Iterator[P] = { val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId - SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, - SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf)) + val ser = Serializer.getSerializer(serializer) + SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, ser) } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index 5fe9f363db453..9a09c05bbc959 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ 
b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -30,6 +30,7 @@ import org.apache.spark.Partitioner import org.apache.spark.ShuffleDependency import org.apache.spark.SparkEnv import org.apache.spark.TaskContext +import org.apache.spark.serializer.Serializer /** * An optimized version of cogroup for set difference/subtraction. @@ -53,10 +54,10 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) { - private var serializerClass: String = null + private var serializer: Serializer = null - def setSerializer(cls: String): SubtractedRDD[K, V, W] = { - serializerClass = cls + def setSerializer(serializer: Serializer): SubtractedRDD[K, V, W] = { + this.serializer = serializer this } @@ -67,7 +68,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( new OneToOneDependency(rdd) } else { logDebug("Adding shuffle dependency with " + rdd) - new ShuffleDependency(rdd, part, serializerClass) + new ShuffleDependency(rdd, part, serializer) } } } @@ -92,7 +93,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = { val partition = p.asInstanceOf[CoGroupPartition] - val serializer = SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf) + val ser = Serializer.getSerializer(serializer) val map = new JHashMap[K, ArrayBuffer[V]] def getSeq(k: K): ArrayBuffer[V] = { val seq = map.get(k) @@ -105,14 +106,13 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( } } def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit) = dep match { - case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { + case NarrowCoGroupSplitDep(rdd, _, itsSplit) => rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op) - } - case ShuffleCoGroupSplitDep(shuffleId) => { + + case ShuffleCoGroupSplitDep(shuffleId) => val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index, - context, serializer) + context, ser) iter.foreach(op) - } } // the first dep is rdd1; add all values to the map integrate(partition.deps(0), t => getSeq(t._1) += t._2) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index dc5b25d845dc2..d83d0341c61ab 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -279,7 +279,7 @@ class DAGScheduler( } else { // Kind of ugly: need to register RDDs with the cache and map output tracker here // since we can't do it in the RDD constructor because # of partitions is unknown - logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") + logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")") mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.size) } stage diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 9d75d7c4ad69a..01cbcc390c6cd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -81,7 +81,7 @@ class JobLogger(val user: String, val logDirName: String) /** * Create a log file for one job * @param jobID ID of the job - * @exception FileNotFoundException Fail to create log file + * @throws 
FileNotFoundException Fail to create log file */ protected def createLogWriter(jobID: Int) { try { @@ -213,14 +213,10 @@ class JobLogger(val user: String, val logDirName: String) * @param indent Indent number before info */ protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) { + val cacheStr = if (rdd.getStorageLevel != StorageLevel.NONE) "CACHED" else "NONE" val rddInfo = - if (rdd.getStorageLevel != StorageLevel.NONE) { - "RDD_ID=" + rdd.id + " " + getRddName(rdd) + " CACHED" + " " + - rdd.origin + " " + rdd.generator - } else { - "RDD_ID=" + rdd.id + " " + getRddName(rdd) + " NONE" + " " + - rdd.origin + " " + rdd.generator - } + s"RDD_ID=$rdd.id ${getRddName(rdd)} $cacheStr " + + s"${rdd.getCreationSite} ${rdd.creationSiteInfo.firstUserClass}" jobLogInfo(jobID, indentString(indent) + rddInfo, false) rdd.dependencies.foreach { case shufDep: ShuffleDependency[_, _] => @@ -275,7 +271,6 @@ class JobLogger(val user: String, val logDirName: String) " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched + " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched + " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime + - " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime + " REMOTE_BYTES_READ=" + metrics.remoteBytesRead case None => "" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index eefc8c232b564..f1924a4573b21 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler /** * A backend interface for scheduling systems that allows plugging in different ones under - * ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as + * TaskSchedulerImpl. We assume a Mesos-like model where the application gets resource offers as * machines become available and can launch tasks on them. */ private[spark] trait SchedulerBackend { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 77789031f464a..2a9edf4a76b97 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -26,6 +26,7 @@ import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDDCheckpointData +import org.apache.spark.serializer.Serializer import org.apache.spark.storage._ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} @@ -153,7 +154,7 @@ private[spark] class ShuffleMapTask( try { // Obtain all the block writers for shuffle blocks. - val ser = SparkEnv.get.serializerManager.get(dep.serializerClass, SparkEnv.get.conf) + val ser = Serializer.getSerializer(dep.serializer) shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser) // Write the map output to its associated buckets. 
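Taken together, the ShuffledRDD, SubtractedRDD and ShuffleMapTask hunks above replace the old string-based serializer lookup with Serializer instances passed around directly; Serializer.getSerializer falls back to the SparkEnv default when no per-shuffle serializer has been set. A minimal sketch of how the new setSerializer API is driven from user code (the object name and the toy job are illustrative, not part of this patch):

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.rdd.ShuffledRDD
import org.apache.spark.serializer.KryoSerializer

object ShuffleSerializerSketch {
  def main(args: Array[String]) {
    val conf = new SparkConf().setMaster("local[2]").setAppName("shuffle-serializer-sketch")
    val sc = new SparkContext(conf)
    val pairs = sc.parallelize(1 to 100).map(i => (i % 10, i))

    // A Serializer instance is wired in directly; leaving it unset keeps the field null,
    // so Serializer.getSerializer falls back to SparkEnv.get.serializer, the app-wide default.
    val shuffled = new ShuffledRDD[Int, Int, (Int, Int)](pairs, new HashPartitioner(4))
      .setSerializer(new KryoSerializer(conf))

    println(shuffled.count())
    sc.stop()
  }
}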
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index a78b0186b9eab..5c1fc30e4a557 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -100,7 +100,7 @@ private[spark] class Stage( id } - val name = callSite.getOrElse(rdd.origin) + val name = callSite.getOrElse(rdd.getCreationSite) override def toString = "Stage " + id diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 1cdfed1d7005e..92616c997e20c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -20,7 +20,7 @@ package org.apache.spark.scheduler import org.apache.spark.scheduler.SchedulingMode.SchedulingMode /** - * Low-level task scheduler interface, currently implemented exclusively by the ClusterScheduler. + * Low-level task scheduler interface, currently implemented exclusively by TaskSchedulerImpl. * This interface allows plugging in different task schedulers. Each TaskScheduler schedulers tasks * for a single SparkContext. These schedulers get sets of tasks submitted to them from the * DAGScheduler for each stage, and are responsible for sending the tasks to the cluster, running diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 8df37c247d0d4..23b06612fd7ab 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -25,6 +25,7 @@ import scala.concurrent.duration._ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet +import scala.util.Random import org.apache.spark._ import org.apache.spark.TaskState.TaskState @@ -207,9 +208,11 @@ private[spark] class TaskSchedulerImpl( } } - // Build a list of tasks to assign to each worker - val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) - val availableCpus = offers.map(o => o.cores).toArray + // Randomly shuffle offers to avoid always placing tasks on the same set of workers. + val shuffledOffers = Random.shuffle(offers) + // Build a list of tasks to assign to each worker. 
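The resourceOffers hunk that continues below now permutes the incoming offers with scala.util.Random.shuffle before building the per-worker task lists, so repeated scheduling rounds do not always fill the same executors first. A standalone illustration of just that shuffling step; Offer is a made-up stand-in for the package-private WorkerOffer case class:

import scala.util.Random

object OfferShuffleSketch extends App {
  // Stand-in for org.apache.spark.scheduler.WorkerOffer, which is private[spark].
  case class Offer(executorId: String, host: String, cores: Int)

  val offers = IndexedSeq(
    Offer("exec-1", "host-1", 4),
    Offer("exec-2", "host-2", 4),
    Offer("exec-3", "host-3", 4))

  // Random.shuffle returns a new, randomly permuted sequence; the input is left untouched.
  // Shuffling before assigning tasks spreads load instead of always starting at exec-1.
  println(Random.shuffle(offers))
}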
+ val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores)) + val availableCpus = shuffledOffers.map(o => o.cores).toArray val sortedTaskSets = rootPool.getSortedTaskSetQueue() for (taskSet <- sortedTaskSets) { logDebug("parentName: %s, name: %s, runningTasks: %s".format( @@ -222,9 +225,9 @@ private[spark] class TaskSchedulerImpl( for (taskSet <- sortedTaskSets; maxLocality <- TaskLocality.values) { do { launchedTask = false - for (i <- 0 until offers.size) { - val execId = offers(i).executorId - val host = offers(i).host + for (i <- 0 until shuffledOffers.size) { + val execId = shuffledOffers(i).executorId + val host = shuffledOffers(i).host for (task <- taskSet.resourceOffer(execId, host, availableCpus(i), maxLocality)) { tasks(i) += task val tid = task.taskId diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 1a4b7e599c01e..5ea4557bbf56a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -26,13 +26,14 @@ import scala.collection.mutable.HashSet import scala.math.max import scala.math.min -import org.apache.spark.{ExceptionFailure, ExecutorLostFailure, FetchFailed, Logging, Resubmitted, SparkEnv, Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState} +import org.apache.spark.{ExceptionFailure, ExecutorLostFailure, FetchFailed, Logging, Resubmitted, + SparkEnv, Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.{Clock, SystemClock} /** - * Schedules the tasks within a single TaskSet in the ClusterScheduler. This class keeps track of + * Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of * each task, retries tasks if they fail (up to a limited number of times), and * handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces * to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, @@ -41,7 +42,7 @@ import org.apache.spark.util.{Clock, SystemClock} * THREADING: This class is designed to only be called from code with a lock on the * TaskScheduler (e.g. its event handlers). It should not be called from other threads. * - * @param sched the ClusterScheduler associated with the TaskSetManager + * @param sched the TaskSchedulerImpl associated with the TaskSetManager * @param taskSet the TaskSet to manage scheduling for * @param maxTaskFailures if any particular task fails more than this number of times, the entire * task set will be aborted diff --git a/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala index ba6bab3f91a65..810b36cddf835 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala @@ -21,4 +21,4 @@ package org.apache.spark.scheduler * Represents free resources available on an executor. 
*/ private[spark] -class WorkerOffer(val executorId: String, val host: String, val cores: Int) +case class WorkerOffer(executorId: String, host: String, cores: Int) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 379e02eb9a437..fad03731572e7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -54,6 +54,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A private val executorAddress = new HashMap[String, Address] private val executorHost = new HashMap[String, String] private val freeCores = new HashMap[String, Int] + private val totalCores = new HashMap[String, Int] private val addressToExecutorId = new HashMap[Address, String] override def preStart() { @@ -76,6 +77,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A sender ! RegisteredExecutor(sparkProperties) executorActor(executorId) = sender executorHost(executorId) = Utils.parseHostPort(hostPort)._1 + totalCores(executorId) = cores freeCores(executorId) = cores executorAddress(executorId) = sender.path.address addressToExecutorId(sender.path.address) = executorId @@ -147,10 +149,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A def removeExecutor(executorId: String, reason: String) { if (executorActor.contains(executorId)) { logInfo("Executor " + executorId + " disconnected, so removing it") - val numCores = freeCores(executorId) - addressToExecutorId -= executorAddress(executorId) + val numCores = totalCores(executorId) executorActor -= executorId executorHost -= executorId + addressToExecutorId -= executorAddress(executorId) + executorAddress -= executorId + totalCores -= executorId freeCores -= executorId totalCoreCount.addAndGet(-numCores) scheduler.executorLost(executorId, SlaveLost(reason)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index c576beb0c0d38..bcf0ce19a54cd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -203,7 +203,7 @@ private[spark] class MesosSchedulerBackend( getResource(offer.getResourcesList, "cpus").toInt) } - // Call into the ClusterScheduler + // Call into the TaskSchedulerImpl val taskLists = scheduler.resourceOffers(offerableWorkers) // Build a list of Mesos tasks for each slave diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 50f7e79e97dd8..16e2f5cf3076d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -35,7 +35,7 @@ private case class KillTask(taskId: Long) /** * Calls to LocalBackend are all serialized through LocalActor. Using an actor makes the calls on * LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend - * and the ClusterScheduler. + * and the TaskSchedulerImpl. 
*/ private[spark] class LocalActor( scheduler: TaskSchedulerImpl, @@ -76,7 +76,7 @@ private[spark] class LocalActor( /** * LocalBackend is used when running a local version of Spark where the executor, backend, and - * master all run in the same JVM. It sits behind a ClusterScheduler and handles launching tasks + * master all run in the same JVM. It sits behind a TaskSchedulerImpl and handles launching tasks * on a single Executor (created by the LocalBackend) running locally. */ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: Int) diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 33c1705ad7c58..18a68b05fa853 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -23,16 +23,34 @@ import java.nio.ByteBuffer import org.apache.spark.SparkConf import org.apache.spark.util.ByteBufferInputStream -private[spark] class JavaSerializationStream(out: OutputStream) extends SerializationStream { - val objOut = new ObjectOutputStream(out) - def writeObject[T](t: T): SerializationStream = { objOut.writeObject(t); this } +private[spark] class JavaSerializationStream(out: OutputStream, counterReset: Int) + extends SerializationStream { + private val objOut = new ObjectOutputStream(out) + private var counter = 0 + + /** + * Calling reset to avoid memory leak: + * http://stackoverflow.com/questions/1281549/memory-leak-traps-in-the-java-standard-api + * But only call it every 10,000th time to avoid bloated serialization streams (when + * the stream 'resets' object class descriptions have to be re-written) + */ + def writeObject[T](t: T): SerializationStream = { + objOut.writeObject(t) + if (counterReset > 0 && counter >= counterReset) { + objOut.reset() + counter = 0 + } else { + counter += 1 + } + this + } def flush() { objOut.flush() } def close() { objOut.close() } } private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader) extends DeserializationStream { - val objIn = new ObjectInputStream(in) { + private val objIn = new ObjectInputStream(in) { override def resolveClass(desc: ObjectStreamClass) = Class.forName(desc.getName, false, loader) } @@ -41,7 +59,7 @@ extends DeserializationStream { def close() { objIn.close() } } -private[spark] class JavaSerializerInstance extends SerializerInstance { +private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance { def serialize[T](t: T): ByteBuffer = { val bos = new ByteArrayOutputStream() val out = serializeStream(bos) @@ -63,7 +81,7 @@ private[spark] class JavaSerializerInstance extends SerializerInstance { } def serializeStream(s: OutputStream): SerializationStream = { - new JavaSerializationStream(s) + new JavaSerializationStream(s, counterReset) } def deserializeStream(s: InputStream): DeserializationStream = { @@ -78,6 +96,16 @@ private[spark] class JavaSerializerInstance extends SerializerInstance { /** * A Spark serializer that uses Java's built-in serialization. 
*/ -class JavaSerializer(conf: SparkConf) extends Serializer { - def newInstance(): SerializerInstance = new JavaSerializerInstance +class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable { + private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 10000) + + def newInstance(): SerializerInstance = new JavaSerializerInstance(counterReset) + + override def writeExternal(out: ObjectOutput) { + out.writeInt(counterReset) + } + + override def readExternal(in: ObjectInput) { + counterReset = in.readInt() + } } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 920490f9d0d61..6b6d814c1fe92 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -34,10 +34,14 @@ import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock} /** * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. */ -class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serializer with Logging { - private val bufferSize = { - conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024 - } +class KryoSerializer(conf: SparkConf) + extends org.apache.spark.serializer.Serializer + with Logging + with Serializable { + + private val bufferSize = conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024 + private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) + private val registrator = conf.getOption("spark.kryo.registrator") def newKryoOutput() = new KryoOutput(bufferSize) @@ -48,7 +52,7 @@ class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serial // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops. // Do this before we invoke the user registrator so the user registrator can override this. - kryo.setReferences(conf.getBoolean("spark.kryo.referenceTracking", true)) + kryo.setReferences(referenceTracking) for (cls <- KryoSerializer.toRegister) kryo.register(cls) @@ -58,7 +62,7 @@ class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serial // Allow the user to register their own classes by setting spark.kryo.registrator try { - for (regCls <- conf.getOption("spark.kryo.registrator")) { + for (regCls <- registrator) { logDebug("Running user registrator: " + regCls) val reg = Class.forName(regCls, true, classLoader).newInstance() .asInstanceOf[KryoRegistrator] diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index 16677ab54be04..099143494b851 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -23,21 +23,31 @@ import java.nio.ByteBuffer import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream import org.apache.spark.util.{ByteBufferInputStream, NextIterator} +import org.apache.spark.SparkEnv /** * A serializer. Because some serialization libraries are not thread safe, this class is used to * create [[org.apache.spark.serializer.SerializerInstance]] objects that do the actual * serialization and are guaranteed to only be called from one thread at a time. * - * Implementations of this trait should have a zero-arg constructor or a constructor that accepts a - * [[org.apache.spark.SparkConf]] as parameter. 
If both constructors are defined, the latter takes - * precedence. + * Implementations of this trait should implement: + * 1. a zero-arg constructor or a constructor that accepts a [[org.apache.spark.SparkConf]] + * as parameter. If both constructors are defined, the latter takes precedence. + * + * 2. Java serialization interface. */ trait Serializer { def newInstance(): SerializerInstance } +object Serializer { + def getSerializer(serializer: Serializer): Serializer = { + if (serializer == null) SparkEnv.get.serializer else serializer + } +} + + /** * An instance of a serializer, for use by one thread at a time. */ diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala deleted file mode 100644 index 65ac0155f45e7..0000000000000 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.serializer - -import java.util.concurrent.ConcurrentHashMap - -import org.apache.spark.SparkConf - -/** - * A service that returns a serializer object given the serializer's class name. If a previous - * instance of the serializer object has been created, the get method returns that instead of - * creating a new one. - */ -private[spark] class SerializerManager { - // TODO: Consider moving this into SparkConf itself to remove the global singleton. - - private val serializers = new ConcurrentHashMap[String, Serializer] - private var _default: Serializer = _ - - def default = _default - - def setDefault(clsName: String, conf: SparkConf): Serializer = { - _default = get(clsName, conf) - _default - } - - def get(clsName: String, conf: SparkConf): Serializer = { - if (clsName == null) { - default - } else { - var serializer = serializers.get(clsName) - if (serializer != null) { - // If the serializer has been created previously, reuse that. - serializer - } else this.synchronized { - // Otherwise, create a new one. But make sure no other thread has attempted - // to create another new one at the same time. - serializer = serializers.get(clsName) - if (serializer == null) { - val clsLoader = Thread.currentThread.getContextClassLoader - val cls = Class.forName(clsName, true, clsLoader) - - // First try with the constructor that takes SparkConf. If we can't find one, - // use a no-arg constructor instead. 
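The deleted SerializerManager (its constructor lookup continues just below) is superseded by the contract now documented on Serializer itself: prefer a constructor taking SparkConf, otherwise fall back to a zero-arg one. A compact sketch of that resolution rule, mirroring the removed code; the helper name instantiate is made up:

import org.apache.spark.SparkConf
import org.apache.spark.serializer.Serializer

object SerializerReflectionSketch {
  // Resolve a Serializer by class name, preferring the (SparkConf) constructor.
  def instantiate(className: String, conf: SparkConf): Serializer = {
    val cls = Class.forName(className, true, Thread.currentThread.getContextClassLoader)
    try {
      cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[Serializer]
    } catch {
      case _: NoSuchMethodException =>
        cls.getConstructor().newInstance().asInstanceOf[Serializer]
    }
  }

  def main(args: Array[String]) {
    val conf = new SparkConf(false)
    val ser = instantiate("org.apache.spark.serializer.JavaSerializer", conf)
    println(ser.getClass.getName)
  }
}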
- try { - val constructor = cls.getConstructor(classOf[SparkConf]) - serializer = constructor.newInstance(conf).asInstanceOf[Serializer] - } catch { - case _: NoSuchMethodException => - val constructor = cls.getConstructor() - serializer = constructor.newInstance().asInstanceOf[Serializer] - } - - serializers.put(clsName, serializer) - } - serializer - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 925022e7fe6fb..bcfc39146a61e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -44,9 +44,13 @@ import org.apache.spark.util.Utils */ private[storage] -trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] - with Logging with BlockFetchTracker { +trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging { def initialize() + def totalBlocks: Int + def numLocalBlocks: Int + def numRemoteBlocks: Int + def fetchWaitTime: Long + def remoteBytesRead: Long } @@ -74,7 +78,6 @@ object BlockFetcherIterator { import blockManager._ private var _remoteBytesRead = 0L - private var _remoteFetchTime = 0L private var _fetchWaitTime = 0L if (blocksByAddress == null) { @@ -120,7 +123,6 @@ object BlockFetcherIterator { future.onSuccess { case Some(message) => { val fetchDone = System.currentTimeMillis() - _remoteFetchTime += fetchDone - fetchStart val bufferMessage = message.asInstanceOf[BufferMessage] val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) for (blockMessage <- blockMessageArray) { @@ -233,7 +235,15 @@ object BlockFetcherIterator { logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") } - //an iterator that will read fetched blocks off the queue as they arrive. + override def totalBlocks: Int = numLocal + numRemote + override def numLocalBlocks: Int = numLocal + override def numRemoteBlocks: Int = numRemote + override def fetchWaitTime: Long = _fetchWaitTime + override def remoteBytesRead: Long = _remoteBytesRead + + + // Implementing the Iterator methods with an iterator that reads fetched blocks off the queue + // as they arrive. @volatile protected var resultsGotten = 0 override def hasNext: Boolean = resultsGotten < _numBlocksToFetch @@ -251,14 +261,6 @@ object BlockFetcherIterator { } (result.blockId, if (result.failed) None else Some(result.deserialize())) } - - // Implementing BlockFetchTracker trait. 
- override def totalBlocks: Int = numLocal + numRemote - override def numLocalBlocks: Int = numLocal - override def numRemoteBlocks: Int = numRemote - override def remoteFetchTime: Long = _remoteFetchTime - override def fetchWaitTime: Long = _fetchWaitTime - override def remoteBytesRead: Long = _remoteBytesRead } // End of BasicBlockFetcherIterator diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index a734ddc1ef702..1bf3f4db32ea7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -29,19 +29,26 @@ import akka.actor.{ActorSystem, Cancellable, Props} import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream} import sun.nio.ch.DirectBuffer -import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} +import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException, SecurityManager} import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer import org.apache.spark.util._ +sealed trait Values + +case class ByteBufferValues(buffer: ByteBuffer) extends Values +case class IteratorValues(iterator: Iterator[Any]) extends Values +case class ArrayBufferValues(buffer: ArrayBuffer[Any]) extends Values + private[spark] class BlockManager( executorId: String, actorSystem: ActorSystem, val master: BlockManagerMaster, val defaultSerializer: Serializer, maxMemory: Long, - val conf: SparkConf) + val conf: SparkConf, + securityManager: SecurityManager) extends Logging { val shuffleBlockManager = new ShuffleBlockManager(this) @@ -60,7 +67,7 @@ private[spark] class BlockManager( if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0 } - val connectionManager = new ConnectionManager(0, conf) + val connectionManager = new ConnectionManager(0, conf, securityManager) implicit val futureExecContext = connectionManager.futureExecContext val blockManagerId = BlockManagerId( @@ -116,8 +123,9 @@ private[spark] class BlockManager( * Construct a BlockManager with a memory limit set based on system properties. 
*/ def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster, - serializer: Serializer, conf: SparkConf) = { - this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), conf) + serializer: Serializer, conf: SparkConf, securityManager: SecurityManager) = { + this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), conf, + securityManager) } /** @@ -455,9 +463,7 @@ private[spark] class BlockManager( def put(blockId: BlockId, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean) : Long = { - val elements = new ArrayBuffer[Any] - elements ++= values - put(blockId, elements, level, tellMaster) + doPut(blockId, IteratorValues(values), level, tellMaster) } /** @@ -479,7 +485,7 @@ private[spark] class BlockManager( def put(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel, tellMaster: Boolean = true) : Long = { require(values != null, "Values is null") - doPut(blockId, Left(values), level, tellMaster) + doPut(blockId, ArrayBufferValues(values), level, tellMaster) } /** @@ -488,10 +494,11 @@ private[spark] class BlockManager( def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) { require(bytes != null, "Bytes is null") - doPut(blockId, Right(bytes), level, tellMaster) + doPut(blockId, ByteBufferValues(bytes), level, tellMaster) } - private def doPut(blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer], + private def doPut(blockId: BlockId, + data: Values, level: StorageLevel, tellMaster: Boolean = true): Long = { require(blockId != null, "BlockId is null") require(level != null && level.isValid, "StorageLevel is null or invalid") @@ -534,8 +541,9 @@ private[spark] class BlockManager( // If we're storing bytes, then initiate the replication before storing them locally. // This is faster as data is already serialized and ready to send. - val replicationFuture = if (data.isRight && level.replication > 1) { - val bufferView = data.right.get.duplicate() // Doesn't copy the bytes, just creates a wrapper + val replicationFuture = if (data.isInstanceOf[ByteBufferValues] && level.replication > 1) { + // Duplicate doesn't copy the bytes, just creates a wrapper + val bufferView = data.asInstanceOf[ByteBufferValues].buffer.duplicate() Future { replicate(blockId, bufferView, level) } @@ -549,34 +557,43 @@ private[spark] class BlockManager( var marked = false try { - data match { - case Left(values) => { - if (level.useMemory) { - // Save it just to memory first, even if it also has useDisk set to true; we will - // drop it to disk later if the memory store can't hold it. - val res = memoryStore.putValues(blockId, values, level, true) - size = res.size - res.data match { - case Right(newBytes) => bytesAfterPut = newBytes - case Left(newIterator) => valuesAfterPut = newIterator - } - } else { - // Save directly to disk. - // Don't get back the bytes unless we replicate them. - val askForBytes = level.replication > 1 - val res = diskStore.putValues(blockId, values, level, askForBytes) - size = res.size - res.data match { - case Right(newBytes) => bytesAfterPut = newBytes - case _ => - } + if (level.useMemory) { + // Save it just to memory first, even if it also has useDisk set to true; we will + // drop it to disk later if the memory store can't hold it. 
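The rewritten doPut that continues below dispatches on the new sealed Values hierarchy (ByteBufferValues, IteratorValues, ArrayBufferValues) instead of an Either, so the compiler can check that every kind of put is handled. A trimmed-down, standalone illustration of the same pattern, using stand-in names rather than the Spark types:

import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer

sealed trait PutData
case class BytesData(buffer: ByteBuffer) extends PutData
case class IteratorData(iterator: Iterator[Any]) extends PutData
case class ArrayData(buffer: ArrayBuffer[Any]) extends PutData

object PutDispatchSketch {
  // Exhaustive match over the sealed hierarchy; a missing case is a compile-time warning.
  def describe(data: PutData): String = data match {
    case BytesData(b)     => s"already-serialized block of ${b.remaining()} bytes"
    case IteratorData(it) => "lazily produced values (size unknown up front)"
    case ArrayData(a)     => s"materialized values, ${a.size} elements"
  }

  def main(args: Array[String]) {
    println(describe(BytesData(ByteBuffer.allocate(16))))
    println(describe(ArrayData(ArrayBuffer[Any](1, 2, 3))))
  }
}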
+ val res = data match { + case IteratorValues(iterator) => + memoryStore.putValues(blockId, iterator, level, true) + case ArrayBufferValues(array) => + memoryStore.putValues(blockId, array, level, true) + case ByteBufferValues(bytes) => { + bytes.rewind(); + memoryStore.putBytes(blockId, bytes, level) + } + } + size = res.size + res.data match { + case Right(newBytes) => bytesAfterPut = newBytes + case Left(newIterator) => valuesAfterPut = newIterator + } + } else { + // Save directly to disk. + // Don't get back the bytes unless we replicate them. + val askForBytes = level.replication > 1 + + val res = data match { + case IteratorValues(iterator) => + diskStore.putValues(blockId, iterator, level, askForBytes) + case ArrayBufferValues(array) => + diskStore.putValues(blockId, array, level, askForBytes) + case ByteBufferValues(bytes) => { + bytes.rewind(); + diskStore.putBytes(blockId, bytes, level) } } - case Right(bytes) => { - bytes.rewind() - // Store it only in memory at first, even if useDisk is also set to true - (if (level.useMemory) memoryStore else diskStore).putBytes(blockId, bytes, level) - size = bytes.limit + size = res.size + res.data match { + case Right(newBytes) => bytesAfterPut = newBytes + case _ => } } @@ -605,8 +622,8 @@ private[spark] class BlockManager( // values and need to serialize and replicate them now: if (level.replication > 1) { data match { - case Right(bytes) => Await.ready(replicationFuture, Duration.Inf) - case Left(values) => { + case ByteBufferValues(bytes) => Await.ready(replicationFuture, Duration.Inf) + case _ => { val remoteStartTime = System.currentTimeMillis // Serialize the block if not already done if (bytesAfterPut == null) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala index b047644b88f48..9a9be047c7245 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala @@ -28,7 +28,7 @@ import org.apache.spark.Logging */ private[spark] abstract class BlockStore(val blockManager: BlockManager) extends Logging { - def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) + def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) : PutResult /** * Put in a block and, possibly, also return its content as either bytes or another Iterator. 
@@ -37,6 +37,9 @@ abstract class BlockStore(val blockManager: BlockManager) extends Logging { * @return a PutResult that contains the size of the data, as well as the values put if * returnValues is true (if not, the result's data field can be null) */ + def putValues(blockId: BlockId, values: Iterator[Any], level: StorageLevel, + returnValues: Boolean) : PutResult + def putValues(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel, returnValues: Boolean) : PutResult diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index d1f07ddb24bb2..36ee4bcc41c66 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -37,7 +37,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage diskManager.getBlockLocation(blockId).length } - override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) { + override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) : PutResult = { // So that we do not modify the input offsets ! // duplicate does not copy buffer, so inexpensive val bytes = _bytes.duplicate() @@ -52,6 +52,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage val finishTime = System.currentTimeMillis logDebug("Block %s stored as %s file on disk in %d ms".format( file.getName, Utils.bytesToString(bytes.limit), (finishTime - startTime))) + return PutResult(bytes.limit(), Right(bytes.duplicate())) } override def putValues( @@ -59,13 +60,22 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage values: ArrayBuffer[Any], level: StorageLevel, returnValues: Boolean) + : PutResult = { + return putValues(blockId, values.toIterator, level, returnValues) + } + + override def putValues( + blockId: BlockId, + values: Iterator[Any], + level: StorageLevel, + returnValues: Boolean) : PutResult = { logDebug("Attempting to write values for block " + blockId) val startTime = System.currentTimeMillis val file = diskManager.getFile(blockId) val outputStream = new FileOutputStream(file) - blockManager.dataSerializeStream(blockId, outputStream, values.iterator) + blockManager.dataSerializeStream(blockId, outputStream, values) val length = file.length val timeTaken = System.currentTimeMillis - startTime diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 18141756518c5..38836d44b04e8 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -49,7 +49,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) { + override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) : PutResult = { // Work on a duplicate - since the original input might be used elsewhere. 
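Both stores lean on ByteBuffer.duplicate() here: a duplicate shares the underlying bytes (no copy) but carries its own position and limit, so rewinding it leaves the caller's view of the buffer untouched. A small self-contained demonstration (the object name is just scratch code):

import java.nio.ByteBuffer

object DuplicateSketch {
  def main(args: Array[String]) {
    val original = ByteBuffer.allocate(8)
    original.putInt(42).putInt(7)   // position of `original` is now 8

    val dup = original.duplicate()  // shares storage, independent position/limit
    dup.rewind()                    // only affects the duplicate
    println(dup.getInt())           // 42, read through the shared storage
    println(original.position())    // still 8
  }
}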
val bytes = _bytes.duplicate() bytes.rewind() @@ -59,8 +59,10 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) elements ++= values val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef]) tryToPut(blockId, elements, sizeEstimate, true) + PutResult(sizeEstimate, Left(values.toIterator)) } else { tryToPut(blockId, bytes, bytes.limit, false) + PutResult(bytes.limit(), Right(bytes.duplicate())) } } @@ -69,14 +71,33 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) values: ArrayBuffer[Any], level: StorageLevel, returnValues: Boolean) - : PutResult = { - + : PutResult = { if (level.deserialized) { val sizeEstimate = SizeEstimator.estimate(values.asInstanceOf[AnyRef]) tryToPut(blockId, values, sizeEstimate, true) - PutResult(sizeEstimate, Left(values.iterator)) + PutResult(sizeEstimate, Left(values.toIterator)) + } else { + val bytes = blockManager.dataSerialize(blockId, values.toIterator) + tryToPut(blockId, bytes, bytes.limit, false) + PutResult(bytes.limit(), Right(bytes.duplicate())) + } + } + + override def putValues( + blockId: BlockId, + values: Iterator[Any], + level: StorageLevel, + returnValues: Boolean) + : PutResult = { + + if (level.deserialized) { + val valueEntries = new ArrayBuffer[Any]() + valueEntries ++= values + val sizeEstimate = SizeEstimator.estimate(valueEntries.asInstanceOf[AnyRef]) + tryToPut(blockId, valueEntries, sizeEstimate, true) + PutResult(sizeEstimate, Left(valueEntries.toIterator)) } else { - val bytes = blockManager.dataSerialize(blockId, values.iterator) + val bytes = blockManager.dataSerialize(blockId, values) tryToPut(blockId, bytes, bytes.limit, false) PutResult(bytes.limit(), Right(bytes.duplicate())) } @@ -215,13 +236,10 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) { val pair = iterator.next() val blockId = pair.getKey - if (rddToAdd.isDefined && rddToAdd == getRddId(blockId)) { - logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " + - "block from the same RDD") - return false + if (rddToAdd.isEmpty || rddToAdd != getRddId(blockId)) { + selectedBlocks += blockId + selectedMemory += pair.getValue.size } - selectedBlocks += blockId - selectedMemory += pair.getValue.size } } @@ -243,6 +261,8 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } return true } else { + logInfo(s"Will not store $blockIdToAdd as it would require dropping another block " + + "from the same RDD") return false } } diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index 1d81d006c0b29..36f2a0fd02724 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -24,6 +24,7 @@ import util.Random import org.apache.spark.SparkConf import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.{SecurityManager, SparkConf} /** * This class tests the BlockManager and MemoryStore for thread safety and @@ -98,7 +99,8 @@ private[spark] object ThreadingTest { val blockManagerMaster = new BlockManagerMaster( actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf))), conf) val blockManager = new BlockManager( - "", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf) + "", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf, + new 
SecurityManager(conf)) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) producers.foreach(_.start) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 1f048a84cdfb6..e0555ca7ac02f 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -18,18 +18,23 @@ package org.apache.spark.ui import java.net.InetSocketAddress -import javax.servlet.http.{HttpServletResponse, HttpServletRequest} +import java.net.URL +import javax.servlet.http.{HttpServlet, HttpServletResponse, HttpServletRequest} import scala.annotation.tailrec import scala.util.{Failure, Success, Try} import scala.xml.Node -import net.liftweb.json.{JValue, pretty, render} -import org.eclipse.jetty.server.{Handler, Request, Server} -import org.eclipse.jetty.server.handler.{AbstractHandler, ContextHandler, HandlerList, ResourceHandler} +import org.json4s.JValue +import org.json4s.jackson.JsonMethods.{pretty, render} + +import org.eclipse.jetty.server.{DispatcherType, Server} +import org.eclipse.jetty.server.handler.HandlerList +import org.eclipse.jetty.servlet.{DefaultServlet, FilterHolder, ServletContextHandler, ServletHolder} import org.eclipse.jetty.util.thread.QueuedThreadPool -import org.apache.spark.Logging +import org.apache.spark.{Logging, SecurityManager, SparkConf} + /** Utilities for launching a web server using Jetty's HTTP Server class */ private[spark] object JettyUtils extends Logging { @@ -38,57 +43,107 @@ private[spark] object JettyUtils extends Logging { type Responder[T] = HttpServletRequest => T - // Conversions from various types of Responder's to jetty Handlers - implicit def jsonResponderToHandler(responder: Responder[JValue]): Handler = - createHandler(responder, "text/json", (in: JValue) => pretty(render(in))) + class ServletParams[T <% AnyRef](val responder: Responder[T], + val contentType: String, + val extractFn: T => String = (in: Any) => in.toString) {} + + // Conversions from various types of Responder's to appropriate servlet parameters + implicit def jsonResponderToServlet(responder: Responder[JValue]): ServletParams[JValue] = + new ServletParams(responder, "text/json", (in: JValue) => pretty(render(in))) - implicit def htmlResponderToHandler(responder: Responder[Seq[Node]]): Handler = - createHandler(responder, "text/html", (in: Seq[Node]) => "" + in.toString) + implicit def htmlResponderToServlet(responder: Responder[Seq[Node]]): ServletParams[Seq[Node]] = + new ServletParams(responder, "text/html", (in: Seq[Node]) => "" + in.toString) - implicit def textResponderToHandler(responder: Responder[String]): Handler = - createHandler(responder, "text/plain") + implicit def textResponderToServlet(responder: Responder[String]): ServletParams[String] = + new ServletParams(responder, "text/plain") - def createHandler[T <% AnyRef](responder: Responder[T], contentType: String, - extractFn: T => String = (in: Any) => in.toString): Handler = { - new AbstractHandler { - def handle(target: String, - baseRequest: Request, - request: HttpServletRequest, + def createServlet[T <% AnyRef](servletParams: ServletParams[T], + securityMgr: SecurityManager): HttpServlet = { + new HttpServlet { + override def doGet(request: HttpServletRequest, response: HttpServletResponse) { - response.setContentType("%s;charset=utf-8".format(contentType)) - 
response.setStatus(HttpServletResponse.SC_OK) - baseRequest.setHandled(true) - val result = responder(request) - response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") - response.getWriter().println(extractFn(result)) + if (securityMgr.checkUIViewPermissions(request.getRemoteUser())) { + response.setContentType("%s;charset=utf-8".format(servletParams.contentType)) + response.setStatus(HttpServletResponse.SC_OK) + val result = servletParams.responder(request) + response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + response.getWriter().println(servletParams.extractFn(result)) + } else { + response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) + response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + response.sendError(HttpServletResponse.SC_UNAUTHORIZED, + "User is not authorized to access this page."); + } } } } + def createServletHandler(path: String, servlet: HttpServlet): ServletContextHandler = { + val contextHandler = new ServletContextHandler() + val holder = new ServletHolder(servlet) + contextHandler.setContextPath(path) + contextHandler.addServlet(holder, "/") + contextHandler + } + /** Creates a handler that always redirects the user to a given path */ - def createRedirectHandler(newPath: String): Handler = { - new AbstractHandler { - def handle(target: String, - baseRequest: Request, - request: HttpServletRequest, + def createRedirectHandler(newPath: String, path: String): ServletContextHandler = { + val servlet = new HttpServlet { + override def doGet(request: HttpServletRequest, response: HttpServletResponse) { - response.setStatus(302) - response.setHeader("Location", baseRequest.getRootURL + newPath) - baseRequest.setHandled(true) + // make sure we don't end up with // in the middle + val newUri = new URL(new URL(request.getRequestURL.toString), newPath).toURI + response.sendRedirect(newUri.toString) } } + val contextHandler = new ServletContextHandler() + val holder = new ServletHolder(servlet) + contextHandler.setContextPath(path) + contextHandler.addServlet(holder, "/") + contextHandler } /** Creates a handler for serving files from a static directory */ - def createStaticHandler(resourceBase: String): ResourceHandler = { - val staticHandler = new ResourceHandler + def createStaticHandler(resourceBase: String, path: String): ServletContextHandler = { + val contextHandler = new ServletContextHandler() + val staticHandler = new DefaultServlet + val holder = new ServletHolder(staticHandler) Option(getClass.getClassLoader.getResource(resourceBase)) match { case Some(res) => - staticHandler.setResourceBase(res.toString) + holder.setInitParameter("resourceBase", res.toString) + holder.setInitParameter("welcomeServlets", "false") + holder.setInitParameter("pathInfoOnly", "false") case None => throw new Exception("Could not find resource path for Web UI: " + resourceBase) } - staticHandler + contextHandler.setContextPath(path) + contextHandler.addServlet(holder, "/") + contextHandler + } + + private def addFilters(handlers: Seq[ServletContextHandler], conf: SparkConf) { + val filters: Array[String] = conf.get("spark.ui.filters", "").split(',').map(_.trim()) + filters.foreach { + case filter : String => + if (!filter.isEmpty) { + logInfo("Adding filter: " + filter) + val holder : FilterHolder = new FilterHolder() + holder.setClassName(filter) + // get any parameters for each filter + val paramName = "spark." 
+ filter + ".params" + val params = conf.get(paramName, "").split(',').map(_.trim()).toSet + params.foreach { + case param : String => + if (!param.isEmpty) { + val parts = param.split("=") + if (parts.length == 2) holder.setInitParameter(parts(0), parts(1)) + } + } + val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.ERROR, + DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.REQUEST) + handlers.foreach { case(handler) => handler.addFilter(holder, "/*", enumDispatcher) } + } + } } /** @@ -98,17 +153,12 @@ private[spark] object JettyUtils extends Logging { * If the desired port number is contented, continues incrementing ports until a free port is * found. Returns the chosen port and the jetty Server object. */ - def startJettyServer(hostName: String, port: Int, handlers: Seq[(String, Handler)]): (Server, Int) - = { - - val handlersToRegister = handlers.map { case(path, handler) => - val contextHandler = new ContextHandler(path) - contextHandler.setHandler(handler) - contextHandler.asInstanceOf[org.eclipse.jetty.server.Handler] - } + def startJettyServer(hostName: String, port: Int, handlers: Seq[ServletContextHandler], + conf: SparkConf): (Server, Int) = { + addFilters(handlers, conf) val handlerList = new HandlerList - handlerList.setHandlers(handlersToRegister.toArray) + handlerList.setHandlers(handlers.toArray) @tailrec def connect(currentPort: Int): (Server, Int) = { @@ -118,7 +168,9 @@ private[spark] object JettyUtils extends Logging { server.setThreadPool(pool) server.setHandler(handlerList) - Try { server.start() } match { + Try { + server.start() + } match { case s: Success[_] => (server, server.getConnectors.head.getLocalPort) case f: Failure[_] => diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index af6b65860e006..5f0dee64fedb7 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -17,7 +17,10 @@ package org.apache.spark.ui -import org.eclipse.jetty.server.{Handler, Server} +import javax.servlet.http.HttpServletRequest + +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{Logging, SparkContext, SparkEnv} import org.apache.spark.ui.JettyUtils._ @@ -34,9 +37,9 @@ private[spark] class SparkUI(sc: SparkContext) extends Logging { var boundPort: Option[Int] = None var server: Option[Server] = None - val handlers = Seq[(String, Handler)]( - ("/static", createStaticHandler(SparkUI.STATIC_RESOURCE_DIR)), - ("/", createRedirectHandler("/stages")) + val handlers = Seq[ServletContextHandler] ( + createStaticHandler(SparkUI.STATIC_RESOURCE_DIR + "/static", "/static"), + createRedirectHandler("/stages", "/") ) val storage = new BlockManagerUI(sc) val jobs = new JobProgressUI(sc) @@ -52,7 +55,7 @@ private[spark] class SparkUI(sc: SparkContext) extends Logging { /** Bind the HTTP server which backs this web interface */ def bind() { try { - val (srv, usedPort) = JettyUtils.startJettyServer(host, port, allHandlers) + val (srv, usedPort) = JettyUtils.startJettyServer(host, port, allHandlers, sc.conf) logInfo("Started Spark Web UI at http://%s:%d".format(host, usedPort)) server = Some(srv) boundPort = Some(usedPort) @@ -83,5 +86,5 @@ private[spark] class SparkUI(sc: SparkContext) extends Logging { private[spark] object SparkUI { val DEFAULT_PORT = "4040" - val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static" + val STATIC_RESOURCE_DIR = 
"org/apache/spark/ui" } diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala index 9e7cdc88162e8..14333476c0e31 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConversions._ import scala.util.Properties import scala.xml.Node -import org.eclipse.jetty.server.Handler +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.SparkContext import org.apache.spark.ui.JettyUtils._ @@ -32,8 +32,9 @@ import org.apache.spark.ui.UIUtils private[spark] class EnvironmentUI(sc: SparkContext) { - def getHandlers = Seq[(String, Handler)]( - ("/environment", (request: HttpServletRequest) => envDetails(request)) + def getHandlers = Seq[ServletContextHandler]( + createServletHandler("/environment", + createServlet((request: HttpServletRequest) => envDetails(request), sc.env.securityManager)) ) def envDetails(request: HttpServletRequest): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala index 1f3b7a4c231b6..4235cfeff9fa2 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala @@ -22,7 +22,7 @@ import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{HashMap, HashSet} import scala.xml.Node -import org.eclipse.jetty.server.Handler +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{ExceptionFailure, Logging, SparkContext} import org.apache.spark.executor.TaskMetrics @@ -43,8 +43,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { sc.addSparkListener(listener) } - def getHandlers = Seq[(String, Handler)]( - ("/executors", (request: HttpServletRequest) => render(request)) + def getHandlers = Seq[ServletContextHandler]( + createServletHandler("/executors", createServlet((request: HttpServletRequest) => render + (request), sc.env.securityManager)) ) def render(request: HttpServletRequest): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala index 557bce6b66353..2d95d47e154cd 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala @@ -23,6 +23,7 @@ import javax.servlet.http.HttpServletRequest import scala.Seq import org.eclipse.jetty.server.Handler +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.SparkContext import org.apache.spark.ui.JettyUtils._ @@ -45,9 +46,15 @@ private[spark] class JobProgressUI(val sc: SparkContext) { def formatDuration(ms: Long) = Utils.msDurationToString(ms) - def getHandlers = Seq[(String, Handler)]( - ("/stages/stage", (request: HttpServletRequest) => stagePage.render(request)), - ("/stages/pool", (request: HttpServletRequest) => poolPage.render(request)), - ("/stages", (request: HttpServletRequest) => indexPage.render(request)) + def getHandlers = Seq[ServletContextHandler]( + createServletHandler("/stages/stage", + createServlet((request: HttpServletRequest) => stagePage.render(request), + sc.env.securityManager)), + createServletHandler("/stages/pool", + createServlet((request: HttpServletRequest) => poolPage.render(request), + sc.env.securityManager)), + 
createServletHandler("/stages", + createServlet((request: HttpServletRequest) => indexPage.render(request), + sc.env.securityManager)) ) } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala index dc18eab74e0da..cb2083eb019bf 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala @@ -19,7 +19,7 @@ package org.apache.spark.ui.storage import javax.servlet.http.HttpServletRequest -import org.eclipse.jetty.server.Handler +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{Logging, SparkContext} import org.apache.spark.ui.JettyUtils._ @@ -29,8 +29,12 @@ private[spark] class BlockManagerUI(val sc: SparkContext) extends Logging { val indexPage = new IndexPage(this) val rddPage = new RDDPage(this) - def getHandlers = Seq[(String, Handler)]( - ("/storage/rdd", (request: HttpServletRequest) => rddPage.render(request)), - ("/storage", (request: HttpServletRequest) => indexPage.render(request)) + def getHandlers = Seq[ServletContextHandler]( + createServletHandler("/storage/rdd", + createServlet((request: HttpServletRequest) => rddPage.render(request), + sc.env.securityManager)), + createServletHandler("/storage", + createServlet((request: HttpServletRequest) => indexPage.render(request), + sc.env.securityManager)) ) } diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index f26ed47e58046..d0ff17db632c1 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -24,12 +24,12 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, IndestructibleActorSystem} import com.typesafe.config.ConfigFactory import org.apache.log4j.{Level, Logger} -import org.apache.spark.SparkConf +import org.apache.spark.{Logging, SecurityManager, SparkConf} /** * Various utility classes for working with Akka. */ -private[spark] object AkkaUtils { +private[spark] object AkkaUtils extends Logging { /** * Creates an ActorSystem ready for remoting, with various Spark features. Returns both the @@ -42,14 +42,14 @@ private[spark] object AkkaUtils { * of a fatal exception. This is used by [[org.apache.spark.executor.Executor]]. 
*/ def createActorSystem(name: String, host: String, port: Int, indestructible: Boolean = false, - conf: SparkConf): (ActorSystem, Int) = { + conf: SparkConf, securityManager: SecurityManager): (ActorSystem, Int) = { val akkaThreads = conf.getInt("spark.akka.threads", 4) val akkaBatchSize = conf.getInt("spark.akka.batchSize", 15) val akkaTimeout = conf.getInt("spark.akka.timeout", 100) - val akkaFrameSize = conf.getInt("spark.akka.frameSize", 10) + val akkaFrameSize = maxFrameSizeBytes(conf) val akkaLogLifecycleEvents = conf.getBoolean("spark.akka.logLifecycleEvents", false) val lifecycleEvents = if (akkaLogLifecycleEvents) "on" else "off" if (!akkaLogLifecycleEvents) { @@ -65,6 +65,15 @@ private[spark] object AkkaUtils { conf.getDouble("spark.akka.failure-detector.threshold", 300.0) val akkaHeartBeatInterval = conf.getInt("spark.akka.heartbeat.interval", 1000) + val secretKey = securityManager.getSecretKey() + val isAuthOn = securityManager.isAuthenticationEnabled() + if (isAuthOn && secretKey == null) { + throw new Exception("Secret key is null with authentication on") + } + val requireCookie = if (isAuthOn) "on" else "off" + val secureCookie = if (isAuthOn) secretKey else "" + logDebug("In createActorSystem, requireCookie is: " + requireCookie) + val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap[String, String]).withFallback( ConfigFactory.parseString( s""" @@ -72,6 +81,8 @@ private[spark] object AkkaUtils { |akka.loggers = [""akka.event.slf4j.Slf4jLogger""] |akka.stdout-loglevel = "ERROR" |akka.jvm-exit-on-fatal-error = off + |akka.remote.require-cookie = "$requireCookie" + |akka.remote.secure-cookie = "$secureCookie" |akka.remote.transport-failure-detector.heartbeat-interval = $akkaHeartBeatInterval s |akka.remote.transport-failure-detector.acceptable-heartbeat-pause = $akkaHeartBeatPauses s |akka.remote.transport-failure-detector.threshold = $akkaFailureDetector @@ -81,7 +92,7 @@ private[spark] object AkkaUtils { |akka.remote.netty.tcp.port = $port |akka.remote.netty.tcp.tcp-nodelay = on |akka.remote.netty.tcp.connection-timeout = $akkaTimeout s - |akka.remote.netty.tcp.maximum-frame-size = ${akkaFrameSize}MiB + |akka.remote.netty.tcp.maximum-frame-size = ${akkaFrameSize}B |akka.remote.netty.tcp.execution-pool-size = $akkaThreads |akka.actor.default-dispatcher.throughput = $akkaBatchSize |akka.log-config-on-start = $logAkkaConfig @@ -110,4 +121,9 @@ private[spark] object AkkaUtils { def lookupTimeout(conf: SparkConf): FiniteDuration = { Duration.create(conf.get("spark.akka.lookupTimeout", "30").toLong, "seconds") } + + /** Returns the configured max frame size for Akka messages in bytes. 
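The unit change above (the MiB suffix replaced by an explicit byte count) is easy to get wrong, so here is a small check of the conversion the new helper performs, assuming the documented default of 10 for spark.akka.frameSize (configured in megabytes).

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
    // spark.akka.frameSize is given in MB; the helper scales it to bytes for Akka
    val frameSizeBytes = conf.getInt("spark.akka.frameSize", 10) * 1024 * 1024
    assert(frameSizeBytes == 10 * 1024 * 1024)   // 10485760 bytes for the 10 MB default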
*/ + def maxFrameSizeBytes(conf: SparkConf): Int = { + conf.getInt("spark.akka.frameSize", 10) * 1024 * 1024 + } } diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 681d0a30cb3f8..a8d20ee332355 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -22,8 +22,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import scala.collection.mutable.Map import scala.collection.mutable.Set -import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type} -import org.objectweb.asm.Opcodes._ +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type} +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ import org.apache.spark.Logging diff --git a/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala b/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala index bf71882ef770a..c539d2f708f95 100644 --- a/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala +++ b/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala @@ -23,9 +23,9 @@ import scala.util.control.{ControlThrowable, NonFatal} import com.typesafe.config.Config /** - * An [[akka.actor.ActorSystem]] which refuses to shut down in the event of a fatal exception. + * An akka.actor.ActorSystem which refuses to shut down in the event of a fatal exception * This is necessary as Spark Executors are allowed to recover from fatal exceptions - * (see [[org.apache.spark.executor.Executor]]). + * (see org.apache.spark.executor.Executor) */ object IndestructibleActorSystem { def apply(name: String, config: Config): ActorSystem = diff --git a/core/src/main/scala/org/apache/spark/util/MutablePair.scala b/core/src/main/scala/org/apache/spark/util/MutablePair.scala index b053266f12748..2c1a6f8fd0a44 100644 --- a/core/src/main/scala/org/apache/spark/util/MutablePair.scala +++ b/core/src/main/scala/org/apache/spark/util/MutablePair.scala @@ -25,10 +25,20 @@ package org.apache.spark.util * @param _2 Element 2 of this MutablePair */ case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T1, - @specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T2] + @specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T2] (var _1: T1, var _2: T2) extends Product2[T1, T2] { + /** No-arg constructor for serialization */ + def this() = this(null.asInstanceOf[T1], null.asInstanceOf[T2]) + + /** Updates this pair with new values and returns itself */ + def update(n1: T1, n2: T2): MutablePair[T1, T2] = { + _1 = n1 + _2 = n2 + this + } + override def toString = "(" + _1 + "," + _2 + ")" override def canEqual(that: Any): Boolean = that.isInstanceOf[MutablePair[_,_]] diff --git a/core/src/main/scala/org/apache/spark/util/StatCounter.scala b/core/src/main/scala/org/apache/spark/util/StatCounter.scala index 5b0d2c36510b8..732748a7ff82b 100644 --- a/core/src/main/scala/org/apache/spark/util/StatCounter.scala +++ b/core/src/main/scala/org/apache/spark/util/StatCounter.scala @@ -19,9 +19,9 @@ package org.apache.spark.util /** * A class for tracking the statistics of a set of numbers (count, mean and variance) in a - * numerically robust way. Includes support for merging two StatCounters. 
Based on - * [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance - * Welford and Chan's algorithms for running variance]]. + * numerically robust way. Includes support for merging two StatCounters. Based on Welford + * and Chan's [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance algorithms]] + * for running variance. * * @constructor Initialize the StatCounter with the given values. */ @@ -29,6 +29,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable { private var n: Long = 0 // Running count of our values private var mu: Double = 0 // Running mean of our values private var m2: Double = 0 // Running variance numerator (sum of (x - mean)^2) + private var maxValue: Double = Double.NegativeInfinity // Running max of our values + private var minValue: Double = Double.PositiveInfinity // Running min of our values merge(values) @@ -41,6 +43,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable { n += 1 mu += delta / n m2 += delta * (value - mu) + maxValue = math.max(maxValue, value) + minValue = math.min(minValue, value) this } @@ -58,7 +62,9 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable { if (n == 0) { mu = other.mu m2 = other.m2 - n = other.n + n = other.n + maxValue = other.maxValue + minValue = other.minValue } else if (other.n != 0) { val delta = other.mu - mu if (other.n * 10 < n) { @@ -70,6 +76,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable { } m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n) n += other.n + maxValue = math.max(maxValue, other.maxValue) + minValue = math.min(minValue, other.minValue) } this } @@ -81,6 +89,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable { other.n = n other.mu = mu other.m2 = m2 + other.maxValue = maxValue + other.minValue = minValue other } @@ -90,6 +100,10 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable { def sum: Double = n * mu + def max: Double = maxValue + + def min: Double = minValue + /** Return the variance of the values. 
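The merge path above combines two running aggregates with the parallel-variance update referenced in the comment; a self-contained sketch of the same arithmetic on plain variables (names such as n1/mu1/m21 are illustrative), handy for checking the combined variance by hand.

    // Combine two (count, mean, M2) triples, where M2 is the sum of squared deviations.
    def combine(n1: Long, mu1: Double, m21: Double,
                n2: Long, mu2: Double, m22: Double): (Long, Double, Double) = {
      if (n1 == 0) (n2, mu2, m22)
      else if (n2 == 0) (n1, mu1, m21)
      else {
        val delta = mu2 - mu1
        val n = n1 + n2
        val mu = mu1 + delta * n2 / n
        val m2 = m21 + m22 + delta * delta * n1 * n2 / n.toDouble
        (n, mu, m2)
      }
    }

    val (n, mu, m2) = combine(2, 1.5, 0.5, 2, 3.5, 0.5)   // {1, 2} merged with {3, 4}
    assert(math.abs(mu - 2.5) < 1e-9)
    assert(math.abs(m2 / n - 1.25) < 1e-9)                 // population variance of {1, 2, 3, 4}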
*/ def variance: Double = { if (n == 0) { @@ -121,7 +135,7 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable { def sampleStdev: Double = math.sqrt(sampleVariance) override def toString: String = { - "(count: %d, mean: %f, stdev: %f)".format(count, mean, stdev) + "(count: %d, mean: %f, stdev: %f, max: %f, min: %f)".format(count, mean, stdev, max, min) } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 8e69f1d3351b5..38a275d438959 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,7 +18,7 @@ package org.apache.spark.util import java.io._ -import java.net.{InetAddress, Inet4Address, NetworkInterface, URI, URL} +import java.net.{InetAddress, Inet4Address, NetworkInterface, URI, URL, URLConnection} import java.nio.ByteBuffer import java.util.{Locale, Random, UUID} import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadPoolExecutor} @@ -33,10 +33,11 @@ import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} import org.apache.spark.deploy.SparkHadoopUtil + /** * Various utility methods used by Spark. */ @@ -232,6 +233,22 @@ private[spark] object Utils extends Logging { } } + /** + * Construct a URI container information used for authentication. + * This also sets the default authenticator to properly negotiation the + * user/password based on the URI. + * + * Note this relies on the Authenticator.setDefault being set properly to decode + * the user name and password. This is currently set in the SecurityManager. + */ + def constructURIForAuthentication(uri: URI, securityMgr: SecurityManager): URI = { + val userCred = securityMgr.getSecretKey() + if (userCred == null) throw new Exception("Secret key is null with authentication on") + val userInfo = securityMgr.getHttpUser() + ":" + userCred + new URI(uri.getScheme(), userInfo, uri.getHost(), uri.getPort(), uri.getPath(), + uri.getQuery(), uri.getFragment()) + } + /** * Download a file requested by the executor. Supports fetching the file in a variety of ways, * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. @@ -239,7 +256,7 @@ private[spark] object Utils extends Logging { * Throws SparkException if the target file already exists and has different contents than * the requested file. 
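As a rough illustration of what constructURIForAuthentication produces, the same userInfo embedding can be done on a plain java.net.URI; the host, port, user name, and secret below are made-up values standing in for what SecurityManager supplies in the patch.

    import java.net.URI

    val original = new URI("http://fileserver:33000/files/FileServerSuite.txt")
    val httpUser = "sparkHttpUser"   // illustrative; the patch uses SecurityManager.getHttpUser()
    val secret   = "good"            // illustrative; the patch uses SecurityManager.getSecretKey()
    val withAuth = new URI(original.getScheme, httpUser + ":" + secret,
      original.getHost, original.getPort, original.getPath,
      original.getQuery, original.getFragment)
    // yields http://sparkHttpUser:good@fileserver:33000/files/FileServerSuite.txt
    println(withAuth)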
*/ - def fetchFile(url: String, targetDir: File, conf: SparkConf) { + def fetchFile(url: String, targetDir: File, conf: SparkConf, securityMgr: SecurityManager) { val filename = url.split("/").last val tempDir = getLocalDir(conf) val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir)) @@ -249,7 +266,23 @@ private[spark] object Utils extends Logging { uri.getScheme match { case "http" | "https" | "ftp" => logInfo("Fetching " + url + " to " + tempFile) - val in = new URL(url).openStream() + + var uc: URLConnection = null + if (securityMgr.isAuthenticationEnabled()) { + logDebug("fetchFile with security enabled") + val newuri = constructURIForAuthentication(uri, securityMgr) + uc = newuri.toURL().openConnection() + uc.setAllowUserInteraction(false) + } else { + logDebug("fetchFile not using security") + uc = new URL(url).openConnection() + } + + val timeout = conf.getInt("spark.files.fetchTimeout", 60) * 1000 + uc.setConnectTimeout(timeout) + uc.setReadTimeout(timeout) + uc.connect() + val in = uc.getInputStream(); val out = new FileOutputStream(tempFile) Utils.copyStream(in, out, true) if (targetFile.exists && !Files.equal(tempFile, targetFile)) { @@ -503,8 +536,6 @@ private[spark] object Utils extends Logging { /** * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes. - * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM - * environment variable. */ def memoryStringToMb(str: String): Int = { val lower = str.toLowerCase @@ -688,8 +719,8 @@ private[spark] object Utils extends Logging { new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass) } - def formatSparkCallSite = { - val callSiteInfo = getCallSiteInfo + /** Returns a printable version of the call site info suitable for logs. */ + def formatCallSiteInfo(callSiteInfo: CallSiteInfo = Utils.getCallSiteInfo) = { "%s at %s:%s".format(callSiteInfo.lastSparkMethod, callSiteInfo.firstUserFile, callSiteInfo.firstUserLine) } diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala index d437c055f33d4..dc4b8f253f259 100644 --- a/core/src/main/scala/org/apache/spark/util/Vector.scala +++ b/core/src/main/scala/org/apache/spark/util/Vector.scala @@ -136,7 +136,7 @@ object Vector { /** * Creates this [[org.apache.spark.util.Vector]] of given length containing random numbers - * between 0.0 and 1.0. Optional [[scala.util.Random]] number generator can be provided. + * between 0.0 and 1.0. Optional scala.util.Random number generator can be provided. 
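For the memoryStringToMb helper whose comment is trimmed above, the documented behaviour amounts to examples like the following; a sketch only, and because Utils is private[spark] the snippet must live under the org.apache.spark package.

    import org.apache.spark.util.Utils

    // JVM-style memory strings are normalised to whole megabytes
    assert(Utils.memoryStringToMb("300m") == 300)
    assert(Utils.memoryStringToMb("1g") == 1024)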
*/ def random(length: Int, random: Random = new XORShiftRandom()) = Vector(length, _ => random.nextDouble()) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index ed74a31f05bae..caa06d5b445b4 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -60,7 +60,7 @@ private[spark] class ExternalAppendOnlyMap[K, V, C]( createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, - serializer: Serializer = SparkEnv.get.serializerManager.default, + serializer: Serializer = SparkEnv.get.serializer, blockManager: BlockManager = SparkEnv.get.blockManager) extends Iterable[(K, C)] with Serializable with Logging { diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala index ca611b67ed91d..8a4cdea2fa7b1 100644 --- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala +++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala @@ -17,8 +17,11 @@ package org.apache.spark.util.random +import java.nio.ByteBuffer import java.util.{Random => JavaRandom} +import scala.util.hashing.MurmurHash3 + import org.apache.spark.util.Utils.timeIt /** @@ -36,8 +39,8 @@ private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) { def this() = this(System.nanoTime) - private var seed = init - + private var seed = XORShiftRandom.hashSeed(init) + // we need to just override next - this will be called by nextInt, nextDouble, // nextGaussian, nextLong, etc. override protected def next(bits: Int): Int = { @@ -49,13 +52,19 @@ private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) { } override def setSeed(s: Long) { - seed = s + seed = XORShiftRandom.hashSeed(s) } } /** Contains benchmark method and main method to run benchmark of the RNG */ private[spark] object XORShiftRandom { + /** Hash seeds to have 0/1 bits throughout. 
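A standalone look at the seed-hashing idea used here: run the 64-bit seed through MurmurHash3 before handing it to the generator, so that nearby raw seeds (for example consecutive nanoTime values) do not yield correlated initial states. This is a minimal variant of the helper, not the patch's exact code.

    import java.nio.ByteBuffer
    import scala.util.hashing.MurmurHash3

    def hashSeed(seed: Long): Long = {
      val bytes = ByteBuffer.allocate(java.lang.Long.SIZE / 8).putLong(seed).array()
      MurmurHash3.bytesHash(bytes)   // Int result, widened to Long
    }

    // consecutive raw seeds map to very different hashed seeds
    println(hashSeed(1L))
    println(hashSeed(2L))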
*/ + private def hashSeed(seed: Long): Long = { + val bytes = ByteBuffer.allocate(java.lang.Long.SIZE).putLong(seed).array() + MurmurHash3.bytesHash(bytes) + } + /** * Main method for running benchmark * @param args takes one argument - the number of random numbers to generate diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 20232e9fbb8d0..40e853c39ca99 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -75,8 +75,9 @@ public int compare(Integer a, Integer b) { else if (a < b) return 1; else return 0; } - }; + } + @SuppressWarnings("unchecked") @Test public void sparkContextUnion() { // Union of non-specialized JavaRDDs @@ -109,6 +110,37 @@ public void sparkContextUnion() { Assert.assertEquals(4, pUnion.count()); } + @SuppressWarnings("unchecked") + @Test + public void intersection() { + List ints1 = Arrays.asList(1, 10, 2, 3, 4, 5); + List ints2 = Arrays.asList(1, 6, 2, 3, 7, 8); + JavaRDD s1 = sc.parallelize(ints1); + JavaRDD s2 = sc.parallelize(ints2); + + JavaRDD intersections = s1.intersection(s2); + Assert.assertEquals(3, intersections.count()); + + ArrayList list = new ArrayList(); + JavaRDD empty = sc.parallelize(list); + JavaRDD emptyIntersection = empty.intersection(s2); + Assert.assertEquals(0, emptyIntersection.count()); + + List doubles = Arrays.asList(1.0, 2.0); + JavaDoubleRDD d1 = sc.parallelizeDoubles(doubles); + JavaDoubleRDD d2 = sc.parallelizeDoubles(doubles); + JavaDoubleRDD dIntersection = d1.intersection(d2); + Assert.assertEquals(2, dIntersection.count()); + + List> pairs = new ArrayList>(); + pairs.add(new Tuple2(1, 2)); + pairs.add(new Tuple2(3, 4)); + JavaPairRDD p1 = sc.parallelizePairs(pairs); + JavaPairRDD p2 = sc.parallelizePairs(pairs); + JavaPairRDD pIntersection = p1.intersection(p2); + Assert.assertEquals(2, pIntersection.count()); + } + @Test public void sortByKey() { List> pairs = new ArrayList>(); @@ -148,6 +180,7 @@ public void call(String s) { Assert.assertEquals(2, foreachCalls); } + @SuppressWarnings("unchecked") @Test public void lookup() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( @@ -179,6 +212,7 @@ public Boolean call(Integer x) { Assert.assertEquals(5, oddsAndEvens.lookup(false).get(0).size()); // Odds } + @SuppressWarnings("unchecked") @Test public void cogroup() { JavaPairRDD categories = sc.parallelizePairs(Arrays.asList( @@ -197,6 +231,7 @@ public void cogroup() { cogrouped.collect(); } + @SuppressWarnings("unchecked") @Test public void leftOuterJoin() { JavaPairRDD rdd1 = sc.parallelizePairs(Arrays.asList( @@ -243,6 +278,7 @@ public Integer call(Integer a, Integer b) { Assert.assertEquals(33, sum); } + @SuppressWarnings("unchecked") @Test public void foldByKey() { List> pairs = Arrays.asList( @@ -265,6 +301,7 @@ public Integer call(Integer a, Integer b) { Assert.assertEquals(3, sums.lookup(3).get(0).intValue()); } + @SuppressWarnings("unchecked") @Test public void reduceByKey() { List> pairs = Arrays.asList( @@ -320,8 +357,8 @@ public void approximateResults() { public void take() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); Assert.assertEquals(1, rdd.first().intValue()); - List firstTwo = rdd.take(2); - List sample = rdd.takeSample(false, 2, 42); + rdd.take(2); + rdd.takeSample(false, 2, 42); } @Test @@ -359,8 +396,8 @@ public Boolean call(Double x) { Assert.assertEquals(2.49444, rdd.stdev(), 0.01); Assert.assertEquals(2.73252, 
rdd.sampleStdev(), 0.01); - Double first = rdd.first(); - List take = rdd.take(5); + rdd.first(); + rdd.take(5); } @Test @@ -380,14 +417,14 @@ public void javaDoubleRDDHistoGram() { @Test public void map() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - JavaDoubleRDD doubles = rdd.map(new DoubleFunction() { + JavaDoubleRDD doubles = rdd.mapToDouble(new DoubleFunction() { @Override - public Double call(Integer x) { + public double call(Integer x) { return 1.0 * x; } }).cache(); doubles.collect(); - JavaPairRDD pairs = rdd.map(new PairFunction() { + JavaPairRDD pairs = rdd.mapToPair(new PairFunction() { @Override public Tuple2 call(Integer x) { return new Tuple2(x, x); @@ -416,7 +453,7 @@ public Iterable call(String x) { Assert.assertEquals("Hello", words.first()); Assert.assertEquals(11, words.count()); - JavaPairRDD pairs = rdd.flatMap( + JavaPairRDD pairs = rdd.flatMapToPair( new PairFlatMapFunction() { @Override @@ -430,7 +467,7 @@ public Iterable> call(String s) { Assert.assertEquals(new Tuple2("Hello", "Hello"), pairs.first()); Assert.assertEquals(11, pairs.count()); - JavaDoubleRDD doubles = rdd.flatMap(new DoubleFlatMapFunction() { + JavaDoubleRDD doubles = rdd.flatMapToDouble(new DoubleFlatMapFunction() { @Override public Iterable call(String s) { List lengths = new LinkedList(); @@ -438,11 +475,11 @@ public Iterable call(String s) { return lengths; } }); - Double x = doubles.first(); - Assert.assertEquals(5.0, doubles.first().doubleValue(), 0.01); + Assert.assertEquals(5.0, doubles.first(), 0.01); Assert.assertEquals(11, pairs.count()); } + @SuppressWarnings("unchecked") @Test public void mapsFromPairsToPairs() { List> pairs = Arrays.asList( @@ -453,7 +490,7 @@ public void mapsFromPairsToPairs() { JavaPairRDD pairRDD = sc.parallelizePairs(pairs); // Regression test for SPARK-668: - JavaPairRDD swapped = pairRDD.flatMap( + JavaPairRDD swapped = pairRDD.flatMapToPair( new PairFlatMapFunction, String, Integer>() { @Override public Iterable> call(Tuple2 item) throws Exception { @@ -463,7 +500,7 @@ public Iterable> call(Tuple2 item) thro swapped.collect(); // There was never a bug here, but it's worth testing: - pairRDD.map(new PairFunction, String, Integer>() { + pairRDD.mapToPair(new PairFunction, String, Integer>() { @Override public Tuple2 call(Tuple2 item) throws Exception { return item.swap(); @@ -509,6 +546,7 @@ public void repartition() { } } + @SuppressWarnings("unchecked") @Test public void persist() { JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); @@ -573,6 +611,7 @@ public void textFilesCompressed() throws IOException { Assert.assertEquals(expected, readRDD.collect()); } + @SuppressWarnings("unchecked") @Test public void sequenceFile() { File tempDir = Files.createTempDir(); @@ -584,7 +623,7 @@ public void sequenceFile() { ); JavaPairRDD rdd = sc.parallelizePairs(pairs); - rdd.map(new PairFunction, IntWritable, Text>() { + rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); @@ -593,7 +632,7 @@ public Tuple2 call(Tuple2 pair) { // Try reading the output back as an object file JavaPairRDD readRDD = sc.sequenceFile(outputDir, IntWritable.class, - Text.class).map(new PairFunction, Integer, String>() { + Text.class).mapToPair(new PairFunction, Integer, String>() { @Override public Tuple2 call(Tuple2 pair) { return new Tuple2(pair._1().get(), pair._2().toString()); @@ -602,6 +641,7 @@ public Tuple2 
call(Tuple2 pair) { Assert.assertEquals(pairs, readRDD.collect()); } + @SuppressWarnings("unchecked") @Test public void writeWithNewAPIHadoopFile() { File tempDir = Files.createTempDir(); @@ -613,7 +653,7 @@ public void writeWithNewAPIHadoopFile() { ); JavaPairRDD rdd = sc.parallelizePairs(pairs); - rdd.map(new PairFunction, IntWritable, Text>() { + rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); @@ -632,6 +672,7 @@ public String call(Tuple2 x) { }).collect().toString()); } + @SuppressWarnings("unchecked") @Test public void readWithNewAPIHadoopFile() throws IOException { File tempDir = Files.createTempDir(); @@ -643,7 +684,7 @@ public void readWithNewAPIHadoopFile() throws IOException { ); JavaPairRDD rdd = sc.parallelizePairs(pairs); - rdd.map(new PairFunction, IntWritable, Text>() { + rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); @@ -674,6 +715,7 @@ public void objectFilesOfInts() { Assert.assertEquals(expected, readRDD.collect()); } + @SuppressWarnings("unchecked") @Test public void objectFilesOfComplexTypes() { File tempDir = Files.createTempDir(); @@ -690,6 +732,7 @@ public void objectFilesOfComplexTypes() { Assert.assertEquals(pairs, readRDD.collect()); } + @SuppressWarnings("unchecked") @Test public void hadoopFile() { File tempDir = Files.createTempDir(); @@ -701,7 +744,7 @@ public void hadoopFile() { ); JavaPairRDD rdd = sc.parallelizePairs(pairs); - rdd.map(new PairFunction, IntWritable, Text>() { + rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); @@ -719,6 +762,7 @@ public String call(Tuple2 x) { }).collect().toString()); } + @SuppressWarnings("unchecked") @Test public void hadoopFileCompressed() { File tempDir = Files.createTempDir(); @@ -730,7 +774,7 @@ public void hadoopFileCompressed() { ); JavaPairRDD rdd = sc.parallelizePairs(pairs); - rdd.map(new PairFunction, IntWritable, Text>() { + rdd.mapToPair(new PairFunction, IntWritable, Text>() { @Override public Tuple2 call(Tuple2 pair) { return new Tuple2(new IntWritable(pair._1()), new Text(pair._2())); @@ -753,9 +797,9 @@ public String call(Tuple2 x) { @Test public void zip() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - JavaDoubleRDD doubles = rdd.map(new DoubleFunction() { + JavaDoubleRDD doubles = rdd.mapToDouble(new DoubleFunction() { @Override - public Double call(Integer x) { + public double call(Integer x) { return 1.0 * x; } }); @@ -824,7 +868,7 @@ public Float zero(Float initialValue) { } }; - final Accumulator floatAccum = sc.accumulator((Float) 10.0f, floatAccumulatorParam); + final Accumulator floatAccum = sc.accumulator(10.0f, floatAccumulatorParam); rdd.foreach(new VoidFunction() { public void call(Integer x) { floatAccum.add((float) x); @@ -876,16 +920,17 @@ public void checkpointAndRestore() { Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect()); } + @SuppressWarnings("unchecked") @Test public void mapOnPairRDD() { JavaRDD rdd1 = sc.parallelize(Arrays.asList(1,2,3,4)); - JavaPairRDD rdd2 = rdd1.map(new PairFunction() { + JavaPairRDD rdd2 = rdd1.mapToPair(new PairFunction() { @Override public Tuple2 call(Integer i) throws Exception { return new Tuple2(i, i % 2); } }); - JavaPairRDD rdd3 = rdd2.map( + JavaPairRDD rdd3 = 
rdd2.mapToPair( new PairFunction, Integer, Integer>() { @Override public Tuple2 call(Tuple2 in) throws Exception { @@ -900,11 +945,12 @@ public Tuple2 call(Tuple2 in) throws Excepti } + @SuppressWarnings("unchecked") @Test public void collectPartitions() { JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3); - JavaPairRDD rdd2 = rdd1.map(new PairFunction() { + JavaPairRDD rdd2 = rdd1.mapToPair(new PairFunction() { @Override public Tuple2 call(Integer i) throws Exception { return new Tuple2(i, i % 2); @@ -968,14 +1014,14 @@ public void countApproxDistinctByKey() { @Test public void collectAsMapWithIntArrayValues() { // Regression test for SPARK-1040 - JavaRDD rdd = sc.parallelize(Arrays.asList(new Integer[] { 1 })); - JavaPairRDD pairRDD = rdd.map(new PairFunction() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1)); + JavaPairRDD pairRDD = rdd.mapToPair(new PairFunction() { @Override public Tuple2 call(Integer x) throws Exception { return new Tuple2(x, new int[] { x }); } }); pairRDD.collect(); // Works fine - Map map = pairRDD.collectAsMap(); // Used to crash with ClassCastException + pairRDD.collectAsMap(); // Used to crash with ClassCastException } } diff --git a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala new file mode 100644 index 0000000000000..d2e303d81c4c8 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite + +import akka.actor._ +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.AkkaUtils +import scala.concurrent.Await + +/** + * Test the AkkaUtils with various security settings. 
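The tests that follow all start from the same two configuration knobs; a condensed sketch of the shared setup, with an illustrative secret value matching the one used in the suite.

    import org.apache.spark.{SecurityManager, SparkConf}

    val conf = new SparkConf()
    conf.set("spark.authenticate", "true")          // turn shared-secret authentication on
    conf.set("spark.authenticate.secret", "good")   // illustrative secret, as in the tests below
    val securityManager = new SecurityManager(conf)
    assert(securityManager.isAuthenticationEnabled())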
+ */ +class AkkaUtilsSuite extends FunSuite with LocalSparkContext { + + test("remote fetch security bad password") { + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + + val securityManager = new SecurityManager(conf); + val hostname = "localhost" + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + conf = conf, securityManager = securityManager) + System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + assert(securityManager.isAuthenticationEnabled() === true) + + val masterTracker = new MapOutputTrackerMaster(conf) + masterTracker.trackerActor = actorSystem.actorOf( + Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + + val badconf = new SparkConf + badconf.set("spark.authenticate", "true") + badconf.set("spark.authenticate.secret", "bad") + val securityManagerBad = new SecurityManager(badconf) + + assert(securityManagerBad.isAuthenticationEnabled() === true) + + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + conf = conf, securityManager = securityManagerBad) + val slaveTracker = new MapOutputTracker(conf) + val selection = slaveSystem.actorSelection( + s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") + val timeout = AkkaUtils.lookupTimeout(conf) + intercept[akka.actor.ActorNotFound] { + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + } + + actorSystem.shutdown() + slaveSystem.shutdown() + } + + test("remote fetch security off") { + val conf = new SparkConf + conf.set("spark.authenticate", "false") + conf.set("spark.authenticate.secret", "bad") + val securityManager = new SecurityManager(conf); + + val hostname = "localhost" + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + conf = conf, securityManager = securityManager) + System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + + assert(securityManager.isAuthenticationEnabled() === false) + + val masterTracker = new MapOutputTrackerMaster(conf) + masterTracker.trackerActor = actorSystem.actorOf( + Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + + val badconf = new SparkConf + badconf.set("spark.authenticate", "false") + badconf.set("spark.authenticate.secret", "good") + val securityManagerBad = new SecurityManager(badconf); + + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + conf = badconf, securityManager = securityManagerBad) + val slaveTracker = new MapOutputTracker(conf) + val selection = slaveSystem.actorSelection( + s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") + val timeout = AkkaUtils.lookupTimeout(conf) + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + + assert(securityManagerBad.isAuthenticationEnabled() === false) + + masterTracker.registerShuffle(10, 1) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val size1000 = MapOutputTracker.decompressSize(compressedSize1000) + masterTracker.registerMapOutput(10, 0, new MapStatus( + BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000))) + masterTracker.incrementEpoch() + 
slaveTracker.updateEpoch(masterTracker.getEpoch) + + // this should succeed since security off + assert(slaveTracker.getServerStatuses(10, 0).toSeq === + Seq((BlockManagerId("a", "hostA", 1000, 0), size1000))) + + actorSystem.shutdown() + slaveSystem.shutdown() + } + + test("remote fetch security pass") { + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + val securityManager = new SecurityManager(conf); + + val hostname = "localhost" + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + conf = conf, securityManager = securityManager) + System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + + assert(securityManager.isAuthenticationEnabled() === true) + + val masterTracker = new MapOutputTrackerMaster(conf) + masterTracker.trackerActor = actorSystem.actorOf( + Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + + val goodconf = new SparkConf + goodconf.set("spark.authenticate", "true") + goodconf.set("spark.authenticate.secret", "good") + val securityManagerGood = new SecurityManager(goodconf); + + assert(securityManagerGood.isAuthenticationEnabled() === true) + + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + conf = goodconf, securityManager = securityManagerGood) + val slaveTracker = new MapOutputTracker(conf) + val selection = slaveSystem.actorSelection( + s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") + val timeout = AkkaUtils.lookupTimeout(conf) + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + + masterTracker.registerShuffle(10, 1) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val size1000 = MapOutputTracker.decompressSize(compressedSize1000) + masterTracker.registerMapOutput(10, 0, new MapStatus( + BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000))) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + // this should succeed since security on and passwords match + assert(slaveTracker.getServerStatuses(10, 0).toSeq === + Seq((BlockManagerId("a", "hostA", 1000, 0), size1000))) + + actorSystem.shutdown() + slaveSystem.shutdown() + } + + test("remote fetch security off client") { + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + + val securityManager = new SecurityManager(conf); + + val hostname = "localhost" + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + conf = conf, securityManager = securityManager) + System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + + assert(securityManager.isAuthenticationEnabled() === true) + + val masterTracker = new MapOutputTrackerMaster(conf) + masterTracker.trackerActor = actorSystem.actorOf( + Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + + val badconf = new SparkConf + badconf.set("spark.authenticate", "false") + badconf.set("spark.authenticate.secret", "bad") + val securityManagerBad = new SecurityManager(badconf); + + assert(securityManagerBad.isAuthenticationEnabled() === false) + + val (slaveSystem, _) = 
AkkaUtils.createActorSystem("spark-slave", hostname, 0, + conf = badconf, securityManager = securityManagerBad) + val slaveTracker = new MapOutputTracker(conf) + val selection = slaveSystem.actorSelection( + s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") + val timeout = AkkaUtils.lookupTimeout(conf) + intercept[akka.actor.ActorNotFound] { + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + } + + actorSystem.shutdown() + slaveSystem.shutdown() + } + +} diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index e022accee6d08..96ba3929c1685 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.FunSuite class BroadcastSuite extends FunSuite with LocalSparkContext { + override def afterEach() { super.afterEach() System.clearProperty("spark.broadcast.factory") diff --git a/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala new file mode 100644 index 0000000000000..80f7ec00c74b2 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite + +import java.nio._ + +import org.apache.spark.network.{ConnectionManager, Message, ConnectionManagerId} +import scala.concurrent.Await +import scala.concurrent.TimeoutException +import scala.concurrent.duration._ + + +/** + * Test the ConnectionManager with various security settings. 
+ */ +class ConnectionManagerSuite extends FunSuite { + + test("security default off") { + val conf = new SparkConf + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + var receivedMessage = false + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + receivedMessage = true + None + }) + + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliablySync(manager.id, bufferMessage) + + assert(receivedMessage == true) + + manager.stop() + } + + test("security on same password") { + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + var numReceivedMessages = 0 + + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedMessages += 1 + None + }) + val managerServer = new ConnectionManager(0, conf, securityManager) + var numReceivedServerMessages = 0 + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedServerMessages += 1 + None + }) + + val size = 10 * 1024 * 1024 + val count = 10 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + + (0 until count).map(i => { + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliablySync(managerServer.id, bufferMessage) + }) + + assert(numReceivedServerMessages == 10) + assert(numReceivedMessages == 0) + + manager.stop() + managerServer.stop() + } + + test("security mismatch password") { + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + var numReceivedMessages = 0 + + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedMessages += 1 + None + }) + + val badconf = new SparkConf + badconf.set("spark.authenticate", "true") + badconf.set("spark.authenticate.secret", "bad") + val badsecurityManager = new SecurityManager(badconf) + val managerServer = new ConnectionManager(0, badconf, badsecurityManager) + var numReceivedServerMessages = 0 + + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedServerMessages += 1 + None + }) + + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliablySync(managerServer.id, bufferMessage) + + assert(numReceivedServerMessages == 0) + assert(numReceivedMessages == 0) + + manager.stop() + managerServer.stop() + } + + test("security mismatch auth off") { + val conf = new SparkConf + conf.set("spark.authenticate", "false") + conf.set("spark.authenticate.secret", "good") + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + var numReceivedMessages = 0 + + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedMessages += 1 + None + }) + + val badconf = new SparkConf + badconf.set("spark.authenticate", "true") + badconf.set("spark.authenticate.secret", "good") + 
val badsecurityManager = new SecurityManager(badconf) + val managerServer = new ConnectionManager(0, badconf, badsecurityManager) + var numReceivedServerMessages = 0 + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedServerMessages += 1 + None + }) + + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + (0 until 1).map(i => { + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliably(managerServer.id, bufferMessage) + }).foreach(f => { + try { + val g = Await.result(f, 1 second) + assert(false) + } catch { + case e: TimeoutException => { + // we should timeout here since the client can't do the negotiation + assert(true) + } + } + }) + + assert(numReceivedServerMessages == 0) + assert(numReceivedMessages == 0) + manager.stop() + managerServer.stop() + } + + test("security auth off") { + val conf = new SparkConf + conf.set("spark.authenticate", "false") + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + var numReceivedMessages = 0 + + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedMessages += 1 + None + }) + + val badconf = new SparkConf + badconf.set("spark.authenticate", "false") + val badsecurityManager = new SecurityManager(badconf) + val managerServer = new ConnectionManager(0, badconf, badsecurityManager) + var numReceivedServerMessages = 0 + + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedServerMessages += 1 + None + }) + + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + (0 until 10).map(i => { + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliably(managerServer.id, bufferMessage) + }).foreach(f => { + try { + val g = Await.result(f, 1 second) + if (!g.isDefined) assert(false) else assert(true) + } catch { + case e: Exception => { + assert(false) + } + } + }) + assert(numReceivedServerMessages == 10) + assert(numReceivedMessages == 0) + + manager.stop() + managerServer.stop() + } + + + +} + diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index e0e8011278649..9cbdfc54a3dc8 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -30,6 +30,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.util.Utils class DriverSuite extends FunSuite with Timeouts { + test("driver should exit after finishing") { val sparkHome = sys.env.get("SPARK_HOME").orElse(sys.props.get("spark.home")).get // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing" diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 9be67b3c95abd..aee9ab9091dac 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -30,6 +30,12 @@ class FileServerSuite extends FunSuite with LocalSparkContext { @transient var tmpFile: File = _ @transient var tmpJarUrl: String = _ + override def beforeEach() { + super.beforeEach() + 
resetSparkContext() + System.setProperty("spark.authenticate", "false") + } + override def beforeAll() { super.beforeAll() val tmpDir = new File(Files.createTempDir(), "test") @@ -43,6 +49,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { val jarFile = new File(tmpDir, "test.jar") val jarStream = new FileOutputStream(jarFile) val jar = new JarOutputStream(jarStream, new java.util.jar.Manifest()) + System.setProperty("spark.authenticate", "false") val jarEntry = new JarEntry(textFile.getName) jar.putNextEntry(jarEntry) @@ -77,6 +84,25 @@ class FileServerSuite extends FunSuite with LocalSparkContext { assert(result.toSet === Set((1,200), (2,300), (3,500))) } + test("Distributing files locally security On") { + val sparkConf = new SparkConf(false) + sparkConf.set("spark.authenticate", "true") + sparkConf.set("spark.authenticate.secret", "good") + sc = new SparkContext("local[4]", "test", sparkConf) + + sc.addFile(tmpFile.toString) + assert(sc.env.securityManager.isAuthenticationEnabled() === true) + val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) + val result = sc.parallelize(testData).reduceByKey { + val path = SparkFiles.get("FileServerSuite.txt") + val in = new BufferedReader(new FileReader(path)) + val fileVal = in.readLine().toInt + in.close() + _ * fileVal + _ * fileVal + }.collect() + assert(result.toSet === Set((1,200), (2,300), (3,500))) + } + test("Distributing files locally using URL as input") { // addFile("file:///....") sc = new SparkContext("local[4]", "test") diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 8ff02aef67aa0..01af94077144a 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -24,6 +24,9 @@ import scala.io.Source import com.google.common.io.Files import org.apache.hadoop.io._ import org.apache.hadoop.io.compress.DefaultCodec +import org.apache.hadoop.mapred.{JobConf, FileAlreadyExistsException, TextOutputFormat} +import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} +import org.apache.hadoop.mapreduce.Job import org.scalatest.FunSuite import org.apache.spark.SparkContext._ @@ -208,4 +211,70 @@ class FileSuite extends FunSuite with LocalSparkContext { assert(rdd.count() === 3) assert(rdd.count() === 3) } + + test ("prevent user from overwriting the empty directory (old Hadoop API)") { + sc = new SparkContext("local", "test") + val tempdir = Files.createTempDir() + val randomRDD = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 1) + intercept[FileAlreadyExistsException] { + randomRDD.saveAsTextFile(tempdir.getPath) + } + } + + test ("prevent user from overwriting the non-empty directory (old Hadoop API)") { + sc = new SparkContext("local", "test") + val tempdir = Files.createTempDir() + val randomRDD = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 1) + randomRDD.saveAsTextFile(tempdir.getPath + "/output") + assert(new File(tempdir.getPath + "/output/part-00000").exists() === true) + intercept[FileAlreadyExistsException] { + randomRDD.saveAsTextFile(tempdir.getPath + "/output") + } + } + + test ("prevent user from overwriting the empty directory (new Hadoop API)") { + sc = new SparkContext("local", "test") + val tempdir = Files.createTempDir() + val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + intercept[FileAlreadyExistsException] { + 
randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempdir.getPath) + } + } + + test ("prevent user from overwriting the non-empty directory (new Hadoop API)") { + sc = new SparkContext("local", "test") + val tempdir = Files.createTempDir() + val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempdir.getPath + "/output") + assert(new File(tempdir.getPath + "/output/part-r-00000").exists() === true) + intercept[FileAlreadyExistsException] { + randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempdir.getPath) + } + } + + test ("save Hadoop Dataset through old Hadoop API") { + sc = new SparkContext("local", "test") + val tempdir = Files.createTempDir() + val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + val job = new JobConf() + job.setOutputKeyClass(classOf[String]) + job.setOutputValueClass(classOf[String]) + job.set("mapred.output.format.class", classOf[TextOutputFormat[String, String]].getName) + job.set("mapred.output.dir", tempdir.getPath + "/outputDataset_old") + randomRDD.saveAsHadoopDataset(job) + assert(new File(tempdir.getPath + "/outputDataset_old/part-00000").exists() === true) + } + + test ("save Hadoop Dataset through new Hadoop API") { + sc = new SparkContext("local", "test") + val tempdir = Files.createTempDir() + val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) + val job = new Job(sc.hadoopConfiguration) + job.setOutputKeyClass(classOf[String]) + job.setOutputValueClass(classOf[String]) + job.setOutputFormatClass(classOf[NewTextOutputFormat[String, String]]) + job.getConfiguration.set("mapred.output.dir", tempdir.getPath + "/outputDataset_new") + randomRDD.saveAsNewAPIHadoopDataset(job.getConfiguration) + assert(new File(tempdir.getPath + "/outputDataset_new/part-r-00000").exists() === true) + } } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 6c1e325f6f348..a5bd72eb0a122 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark import scala.concurrent.Await import akka.actor._ +import akka.testkit.TestActorRef import org.scalatest.FunSuite import org.apache.spark.scheduler.MapStatus @@ -51,14 +52,16 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { test("master start and stop") { val actorSystem = ActorSystem("test") val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))) + tracker.trackerActor = + actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) tracker.stop() } test("master register and fetch") { val actorSystem = ActorSystem("test") val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))) + tracker.trackerActor = + actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) tracker.registerShuffle(10, 2) val compressedSize1000 = MapOutputTracker.compressSize(1000L) val compressedSize10000 = MapOutputTracker.compressSize(10000L) @@ -77,7 +80,8 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { test("master register and 
unregister and fetch") { val actorSystem = ActorSystem("test") val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))) + tracker.trackerActor = + actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) tracker.registerShuffle(10, 2) val compressedSize1000 = MapOutputTracker.compressSize(1000L) val compressedSize10000 = MapOutputTracker.compressSize(10000L) @@ -98,14 +102,18 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { test("remote fetch") { val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf) - System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, + securityManager = new SecurityManager(conf)) + + // Will be cleared by LocalSparkContext + System.setProperty("spark.driver.port", boundPort.toString) val masterTracker = new MapOutputTrackerMaster(conf) masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") + Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf) + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf, + securityManager = new SecurityManager(conf)) val slaveTracker = new MapOutputTracker(conf) val selection = slaveSystem.actorSelection( s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") @@ -124,7 +132,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000, 0), size1000))) + Seq((BlockManagerId("a", "hostA", 1000, 0), size1000))) masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0)) masterTracker.incrementEpoch() @@ -134,4 +142,44 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { // failure should be cached intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } } + + test("remote fetch below akka frame size") { + val newConf = new SparkConf + newConf.set("spark.akka.frameSize", "1") + newConf.set("spark.akka.askTimeout", "1") // Fail fast + + val masterTracker = new MapOutputTrackerMaster(conf) + val actorSystem = ActorSystem("test") + val actorRef = TestActorRef[MapOutputTrackerMasterActor]( + new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem) + val masterActor = actorRef.underlyingActor + + // Frame size should be ~123B, and no exception should be thrown + masterTracker.registerShuffle(10, 1) + masterTracker.registerMapOutput(10, 0, new MapStatus( + BlockManagerId("88", "mph", 1000, 0), Array.fill[Byte](10)(0))) + masterActor.receive(GetMapOutputStatuses(10)) + } + + test("remote fetch exceeds akka frame size") { + val newConf = new SparkConf + newConf.set("spark.akka.frameSize", "1") + newConf.set("spark.akka.askTimeout", "1") // Fail fast + + val masterTracker = new MapOutputTrackerMaster(conf) + val actorSystem = ActorSystem("test") + val actorRef = TestActorRef[MapOutputTrackerMasterActor]( + new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem) + val masterActor = actorRef.underlyingActor + + // Frame 
size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception. + // Note that the size is hand-selected here because map output statuses are compressed before + // being sent. + masterTracker.registerShuffle(20, 100) + (0 until 100).foreach { i => + masterTracker.registerMapOutput(20, i, new MapStatus( + BlockManagerId("999", "mps", 1000, 0), Array.fill[Byte](4000000)(0))) + } + intercept[SparkException] { masterActor.receive(GetMapOutputStatuses(20)) } + } } diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index 4305686d3a6d5..996db70809320 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -171,6 +171,8 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet assert(abs(6.0/2 - rdd.mean) < 0.01) assert(abs(1.0 - rdd.variance) < 0.01) assert(abs(1.0 - rdd.stdev) < 0.01) + assert(stats.max === 4.0) + assert(stats.min === 2.0) // Add other tests here for classes that should be able to handle empty partitions correctly } diff --git a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala index 3a0385a1b0bd9..0bac78d8a6bdf 100644 --- a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala @@ -19,74 +19,152 @@ package org.apache.spark import org.scalatest.FunSuite + +import org.apache.spark.rdd.{HadoopRDD, PipedRDD, HadoopPartition} +import org.apache.hadoop.mapred.{JobConf, TextInputFormat, FileSplit} +import org.apache.hadoop.fs.Path + +import scala.collection.Map +import scala.sys.process._ +import scala.util.Try +import org.apache.hadoop.io.{Text, LongWritable} + class PipedRDDSuite extends FunSuite with SharedSparkContext { test("basic pipe") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + if (testCommandAvailable("cat")) { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val piped = nums.pipe(Seq("cat")) + val piped = nums.pipe(Seq("cat")) - val c = piped.collect() - assert(c.size === 4) - assert(c(0) === "1") - assert(c(1) === "2") - assert(c(2) === "3") - assert(c(3) === "4") + val c = piped.collect() + assert(c.size === 4) + assert(c(0) === "1") + assert(c(1) === "2") + assert(c(2) === "3") + assert(c(3) === "4") + } else { + assert(true) + } } test("advanced pipe") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val bl = sc.broadcast(List("0")) - - val piped = nums.pipe(Seq("cat"), - Map[String, String](), - (f: String => Unit) => {bl.value.map(f(_));f("\u0001")}, - (i:Int, f: String=> Unit) => f(i + "_")) - - val c = piped.collect() - - assert(c.size === 8) - assert(c(0) === "0") - assert(c(1) === "\u0001") - assert(c(2) === "1_") - assert(c(3) === "2_") - assert(c(4) === "0") - assert(c(5) === "\u0001") - assert(c(6) === "3_") - assert(c(7) === "4_") - - val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2) - val d = nums1.groupBy(str=>str.split("\t")(0)). 
- pipe(Seq("cat"), - Map[String, String](), - (f: String => Unit) => {bl.value.map(f(_));f("\u0001")}, - (i:Tuple2[String, Seq[String]], f: String=> Unit) => {for (e <- i._2){ f(e + "_")}}).collect() - assert(d.size === 8) - assert(d(0) === "0") - assert(d(1) === "\u0001") - assert(d(2) === "b\t2_") - assert(d(3) === "b\t4_") - assert(d(4) === "0") - assert(d(5) === "\u0001") - assert(d(6) === "a\t1_") - assert(d(7) === "a\t3_") + if (testCommandAvailable("cat")) { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val bl = sc.broadcast(List("0")) + + val piped = nums.pipe(Seq("cat"), + Map[String, String](), + (f: String => Unit) => { + bl.value.map(f(_)); f("\u0001") + }, + (i: Int, f: String => Unit) => f(i + "_")) + + val c = piped.collect() + + assert(c.size === 8) + assert(c(0) === "0") + assert(c(1) === "\u0001") + assert(c(2) === "1_") + assert(c(3) === "2_") + assert(c(4) === "0") + assert(c(5) === "\u0001") + assert(c(6) === "3_") + assert(c(7) === "4_") + + val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2) + val d = nums1.groupBy(str => str.split("\t")(0)). + pipe(Seq("cat"), + Map[String, String](), + (f: String => Unit) => { + bl.value.map(f(_)); f("\u0001") + }, + (i: Tuple2[String, Seq[String]], f: String => Unit) => { + for (e <- i._2) { + f(e + "_") + } + }).collect() + assert(d.size === 8) + assert(d(0) === "0") + assert(d(1) === "\u0001") + assert(d(2) === "b\t2_") + assert(d(3) === "b\t4_") + assert(d(4) === "0") + assert(d(5) === "\u0001") + assert(d(6) === "a\t1_") + assert(d(7) === "a\t3_") + } else { + assert(true) + } } test("pipe with env variable") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA")) - val c = piped.collect() - assert(c.size === 2) - assert(c(0) === "LALALA") - assert(c(1) === "LALALA") + if (testCommandAvailable("printenv")) { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA")) + val c = piped.collect() + assert(c.size === 2) + assert(c(0) === "LALALA") + assert(c(1) === "LALALA") + } else { + assert(true) + } } test("pipe with non-zero exit status") { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val piped = nums.pipe(Seq("cat nonexistent_file", "2>", "/dev/null")) - intercept[SparkException] { - piped.collect() + if (testCommandAvailable("cat")) { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val piped = nums.pipe(Seq("cat nonexistent_file", "2>", "/dev/null")) + intercept[SparkException] { + piped.collect() + } + } else { + assert(true) } } + test("test pipe exports map_input_file") { + testExportInputFile("map_input_file") + } + + test("test pipe exports mapreduce_map_input_file") { + testExportInputFile("mapreduce_map_input_file") + } + + def testCommandAvailable(command: String): Boolean = { + Try(Process(command) !!).isSuccess + } + + def testExportInputFile(varName: String) { + if (testCommandAvailable("printenv")) { + val nums = new HadoopRDD(sc, new JobConf(), classOf[TextInputFormat], classOf[LongWritable], + classOf[Text], 2) { + override def getPartitions: Array[Partition] = Array(generateFakeHadoopPartition()) + + override val getDependencies = List[Dependency[_]]() + + override def compute(theSplit: Partition, context: TaskContext) = { + new InterruptibleIterator[(LongWritable, Text)](context, Iterator((new LongWritable(1), + new Text("b")))) + } + } + val hadoopPart1 = generateFakeHadoopPartition() + val pipedRdd = new PipedRDD(nums, "printenv " + 
varName) + val tContext = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false, + taskMetrics = null) + val rddIter = pipedRdd.compute(hadoopPart1, tContext) + val arr = rddIter.toArray + assert(arr(0) == "/some/path") + } else { + // printenv isn't available so just pass the test + assert(true) + } + } + + def generateFakeHadoopPartition(): HadoopPartition = { + val split = new FileSplit(new Path("/some/path"), 0, 1, + Array[String]("loc1", "loc2", "loc3", "loc4", "loc5")) + new HadoopPartition(sc.newRddId(), 1, split) + } + } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index abea36f7c83df..be6508a40ea61 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -27,6 +27,9 @@ import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.MutablePair class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { + + val conf = new SparkConf(loadDefaults = false) + test("groupByKey without compression") { try { System.setProperty("spark.shuffle.compress", "false") @@ -54,7 +57,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { // If the Kryo serializer is not used correctly, the shuffle would fail because the // default Java serializer cannot handle the non serializable class. val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)]( - b, new HashPartitioner(NUM_BLOCKS)).setSerializer(classOf[KryoSerializer].getName) + b, new HashPartitioner(NUM_BLOCKS)).setSerializer(new KryoSerializer(conf)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId assert(c.count === 10) @@ -76,7 +79,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { // If the Kryo serializer is not used correctly, the shuffle would fail because the // default Java serializer cannot handle the non serializable class. val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)]( - b, new HashPartitioner(3)).setSerializer(classOf[KryoSerializer].getName) + b, new HashPartitioner(3)).setSerializer(new KryoSerializer(conf)) assert(c.count === 10) } @@ -92,7 +95,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { // NOTE: The default Java serializer doesn't create zero-sized blocks. 
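For reference on the guard pattern the PipedRDDSuite changes above introduce: below is a minimal, self-contained editorial sketch — not part of this patch, and the object name and `main` harness are illustrative only — of checking whether an external command exists before exercising shell-dependent code, using `scala.sys.process` and `Try` in the same way as the suite's `testCommandAvailable` helper.

```scala
import scala.sys.process._
import scala.util.Try

object CommandCheck {
  // True when the command can be launched and exits with status 0.
  // stdin is not connected by default, so a bare `cat` sees EOF and returns immediately.
  def testCommandAvailable(command: String): Boolean =
    Try(Process(command).!!).isSuccess

  def main(args: Array[String]): Unit = {
    // Skip shell-dependent test logic on platforms that lack the tool (e.g. Windows).
    if (testCommandAvailable("cat")) {
      println("cat is available; pipe-based tests would run")
    } else {
      println("cat not found; pipe-based tests would be skipped")
    }
  }
}
```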
// So, use Kryo val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10)) - .setSerializer(classOf[KryoSerializer].getName) + .setSerializer(new KryoSerializer(conf)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId assert(c.count === 4) diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index f28d5c7b133b3..3bb936790d506 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -95,6 +95,10 @@ class SparkContextSchedulerCreationSuite } } + test("yarn-cluster") { + testYarn("yarn-cluster", "org.apache.spark.scheduler.cluster.YarnClusterScheduler") + } + test("yarn-standalone") { testYarn("yarn-standalone", "org.apache.spark.scheduler.cluster.YarnClusterScheduler") } diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index de866ed7ffed8..bae3b37e267d5 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -20,9 +20,12 @@ package org.apache.spark.deploy import java.io.File import java.util.Date -import net.liftweb.json.Diff -import net.liftweb.json.{JsonAST, JsonParser} -import net.liftweb.json.JsonAST.{JNothing, JValue} +import org.json4s._ + +import org.json4s.JValue +import org.json4s.jackson.JsonMethods +import com.fasterxml.jackson.core.JsonParseException + import org.scalatest.FunSuite import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} @@ -34,31 +37,31 @@ class JsonProtocolSuite extends FunSuite { test("writeApplicationInfo") { val output = JsonProtocol.writeApplicationInfo(createAppInfo()) assertValidJson(output) - assertValidDataInJson(output, JsonParser.parse(JsonConstants.appInfoJsonStr)) + assertValidDataInJson(output, JsonMethods.parse(JsonConstants.appInfoJsonStr)) } test("writeWorkerInfo") { val output = JsonProtocol.writeWorkerInfo(createWorkerInfo()) assertValidJson(output) - assertValidDataInJson(output, JsonParser.parse(JsonConstants.workerInfoJsonStr)) + assertValidDataInJson(output, JsonMethods.parse(JsonConstants.workerInfoJsonStr)) } test("writeApplicationDescription") { val output = JsonProtocol.writeApplicationDescription(createAppDesc()) assertValidJson(output) - assertValidDataInJson(output, JsonParser.parse(JsonConstants.appDescJsonStr)) + assertValidDataInJson(output, JsonMethods.parse(JsonConstants.appDescJsonStr)) } test("writeExecutorRunner") { val output = JsonProtocol.writeExecutorRunner(createExecutorRunner()) assertValidJson(output) - assertValidDataInJson(output, JsonParser.parse(JsonConstants.executorRunnerJsonStr)) + assertValidDataInJson(output, JsonMethods.parse(JsonConstants.executorRunnerJsonStr)) } test("writeDriverInfo") { val output = JsonProtocol.writeDriverInfo(createDriverInfo()) assertValidJson(output) - assertValidDataInJson(output, JsonParser.parse(JsonConstants.driverInfoJsonStr)) + assertValidDataInJson(output, JsonMethods.parse(JsonConstants.driverInfoJsonStr)) } test("writeMasterState") { @@ -71,7 +74,7 @@ class JsonProtocolSuite extends FunSuite { activeDrivers, completedDrivers, RecoveryState.ALIVE) val output = JsonProtocol.writeMasterState(stateResponse) assertValidJson(output) - 
assertValidDataInJson(output, JsonParser.parse(JsonConstants.masterStateJsonStr)) + assertValidDataInJson(output, JsonMethods.parse(JsonConstants.masterStateJsonStr)) } test("writeWorkerState") { @@ -83,7 +86,7 @@ class JsonProtocolSuite extends FunSuite { finishedExecutors, drivers, finishedDrivers, "masterUrl", 4, 1234, 4, 1234, "masterWebUiUrl") val output = JsonProtocol.writeWorkerState(stateResponse) assertValidJson(output) - assertValidDataInJson(output, JsonParser.parse(JsonConstants.workerStateJsonStr)) + assertValidDataInJson(output, JsonMethods.parse(JsonConstants.workerStateJsonStr)) } def createAppDesc(): ApplicationDescription = { @@ -125,9 +128,9 @@ class JsonProtocolSuite extends FunSuite { def assertValidJson(json: JValue) { try { - JsonParser.parse(JsonAST.compactRender(json)) + JsonMethods.parse(JsonMethods.compact(json)) } catch { - case e: JsonParser.ParseException => fail("Invalid Json detected", e) + case e: JsonParseException => fail("Invalid Json detected", e) } } diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala index c1e8b295dfe3b..96a5a1231813e 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala @@ -18,21 +18,22 @@ package org.apache.spark.metrics import org.scalatest.{BeforeAndAfter, FunSuite} - -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.master.MasterSource class MetricsSystemSuite extends FunSuite with BeforeAndAfter { var filePath: String = _ var conf: SparkConf = null + var securityMgr: SecurityManager = null before { filePath = getClass.getClassLoader.getResource("test_metrics_system.properties").getFile() conf = new SparkConf(false).set("spark.metrics.conf", filePath) + securityMgr = new SecurityManager(conf) } test("MetricsSystem with default config") { - val metricsSystem = MetricsSystem.createMetricsSystem("default", conf) + val metricsSystem = MetricsSystem.createMetricsSystem("default", conf, securityMgr) val sources = metricsSystem.sources val sinks = metricsSystem.sinks @@ -42,7 +43,7 @@ class MetricsSystemSuite extends FunSuite with BeforeAndAfter { } test("MetricsSystem with sources add") { - val metricsSystem = MetricsSystem.createMetricsSystem("test", conf) + val metricsSystem = MetricsSystem.createMetricsSystem("test", conf, securityMgr) val sources = metricsSystem.sources val sinks = metricsSystem.sinks diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index e3e23775f011d..f9e994b13dfbc 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -347,6 +347,48 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { */ pairs.saveAsNewAPIHadoopFile[ConfigTestFormat]("ignored") } + + test("lookup") { + val pairs = sc.parallelize(Array((1,2), (3,4), (5,6), (5,7))) + + assert(pairs.partitioner === None) + assert(pairs.lookup(1) === Seq(2)) + assert(pairs.lookup(5) === Seq(6,7)) + assert(pairs.lookup(-1) === Seq()) + + } + + test("lookup with partitioner") { + val pairs = sc.parallelize(Array((1,2), (3,4), (5,6), (5,7))) + + val p = new Partitioner { + def numPartitions: Int = 2 + + def getPartition(key: Any): Int = Math.abs(key.hashCode() % 2) + } + 
val shuffled = pairs.partitionBy(p) + + assert(shuffled.partitioner === Some(p)) + assert(shuffled.lookup(1) === Seq(2)) + assert(shuffled.lookup(5) === Seq(6,7)) + assert(shuffled.lookup(-1) === Seq()) + } + + test("lookup with bad partitioner") { + val pairs = sc.parallelize(Array((1,2), (3,4), (5,6), (5,7))) + + val p = new Partitioner { + def numPartitions: Int = 2 + + def getPartition(key: Any): Int = key.hashCode() % 2 + } + val shuffled = pairs.partitionBy(p) + + assert(shuffled.partitioner === Some(p)) + assert(shuffled.lookup(1) === Seq(2)) + intercept[IllegalArgumentException] {shuffled.lookup(-1)} + } + } /* diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 60bcada55245b..d6b5fdc7984b4 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -47,6 +47,8 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4))) assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4")) assert(nums.keyBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4))) + assert(nums.max() === 4) + assert(nums.min() === 1) val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _))) assert(partitionSums.collect().toList === List(3, 7)) @@ -457,6 +459,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("takeSample") { val data = sc.parallelize(1 to 100, 2) + for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement=false, 20, seed) assert(sample.size === 20) // Got exactly 20 elements @@ -488,6 +491,12 @@ class RDDSuite extends FunSuite with SharedSparkContext { } } + test("takeSample from an empty rdd") { + val emptySet = sc.parallelize(Seq.empty[Int], 2) + val sample = emptySet.takeSample(false, 20, 1) + assert(sample.length === 0) + } + test("randomSplit") { val n = 600 val data = sc.parallelize(1 to n, 2) diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index 0b90c4e74c8a4..0a7cb69416a08 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -24,3 +24,19 @@ class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int override def preferredLocations: Seq[TaskLocation] = prefLocs } + +object FakeTask { + /** + * Utility method to create a TaskSet, potentially setting a particular sequence of preferred + * locations for each task (given as varargs) if this sequence is not empty. 
+ */ + def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { + if (prefLocs.size != 0 && prefLocs.size != numTasks) { + throw new IllegalArgumentException("Wrong number of task locations") + } + val tasks = Array.tabulate[Task[_]](numTasks) { i => + new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil) + } + new TaskSet(tasks, 0, 0, 0, null) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 368c5154ea3b9..7c4f2b4361892 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -129,7 +129,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc sm.localBlocksFetched should be > (0) sm.remoteBlocksFetched should be (0) sm.remoteBytesRead should be (0l) - sm.remoteFetchTime should be (0l) } } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index ac07f60e284bb..c4e7a4bb7d385 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -93,10 +93,10 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA // If this test hangs, it's probably because no resource offers were made after the task // failed. val scheduler: TaskSchedulerImpl = sc.taskScheduler match { - case clusterScheduler: TaskSchedulerImpl => - clusterScheduler + case taskScheduler: TaskSchedulerImpl => + taskScheduler case _ => - assert(false, "Expect local cluster to use ClusterScheduler") + assert(false, "Expect local cluster to use TaskSchedulerImpl") throw new ClassCastException } scheduler.taskResultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler) diff --git a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala similarity index 66% rename from core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala rename to core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 85e929925e3b5..6b0800af9c6d0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -25,13 +25,20 @@ import org.scalatest.FunSuite import org.apache.spark._ +class FakeSchedulerBackend extends SchedulerBackend { + def start() {} + def stop() {} + def reviveOffers() {} + def defaultParallelism() = 1 +} + class FakeTaskSetManager( initPriority: Int, initStageId: Int, initNumTasks: Int, - clusterScheduler: TaskSchedulerImpl, + taskScheduler: TaskSchedulerImpl, taskSet: TaskSet) - extends TaskSetManager(clusterScheduler, taskSet, 0) { + extends TaskSetManager(taskScheduler, taskSet, 0) { parent = null weight = 1 @@ -105,9 +112,10 @@ class FakeTaskSetManager( } } -class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging { +class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Logging { - def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: TaskSchedulerImpl, taskSet: TaskSet): FakeTaskSetManager = { + def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: TaskSchedulerImpl, + 
taskSet: TaskSet): FakeTaskSetManager = { new FakeTaskSetManager(priority, stage, numTasks, cs , taskSet) } @@ -133,20 +141,17 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging } test("FIFO Scheduler Test") { - sc = new SparkContext("local", "ClusterSchedulerSuite") - val clusterScheduler = new TaskSchedulerImpl(sc) - var tasks = ArrayBuffer[Task[_]]() - val task = new FakeTask(0) - tasks += task - val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + sc = new SparkContext("local", "TaskSchedulerImplSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + val taskSet = FakeTask.createTaskSet(1) val rootPool = new Pool("", SchedulingMode.FIFO, 0, 0) val schedulableBuilder = new FIFOSchedulableBuilder(rootPool) schedulableBuilder.buildPools() - val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, clusterScheduler, taskSet) - val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, clusterScheduler, taskSet) - val taskSetManager2 = createDummyTaskSetManager(0, 2, 2, clusterScheduler, taskSet) + val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, taskScheduler, taskSet) + val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, taskScheduler, taskSet) + val taskSetManager2 = createDummyTaskSetManager(0, 2, 2, taskScheduler, taskSet) schedulableBuilder.addTaskSetManager(taskSetManager0, null) schedulableBuilder.addTaskSetManager(taskSetManager1, null) schedulableBuilder.addTaskSetManager(taskSetManager2, null) @@ -160,12 +165,9 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging } test("Fair Scheduler Test") { - sc = new SparkContext("local", "ClusterSchedulerSuite") - val clusterScheduler = new TaskSchedulerImpl(sc) - var tasks = ArrayBuffer[Task[_]]() - val task = new FakeTask(0) - tasks += task - val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + sc = new SparkContext("local", "TaskSchedulerImplSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + val taskSet = FakeTask.createTaskSet(1) val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() System.setProperty("spark.scheduler.allocation.file", xmlPath) @@ -189,15 +191,15 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging val properties2 = new Properties() properties2.setProperty("spark.scheduler.pool","2") - val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, clusterScheduler, taskSet) - val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, clusterScheduler, taskSet) - val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, clusterScheduler, taskSet) + val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, taskScheduler, taskSet) + val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, taskScheduler, taskSet) + val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, taskScheduler, taskSet) schedulableBuilder.addTaskSetManager(taskSetManager10, properties1) schedulableBuilder.addTaskSetManager(taskSetManager11, properties1) schedulableBuilder.addTaskSetManager(taskSetManager12, properties1) - val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, clusterScheduler, taskSet) - val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, clusterScheduler, taskSet) + val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, taskScheduler, taskSet) + val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, taskScheduler, taskSet) schedulableBuilder.addTaskSetManager(taskSetManager23, properties2) schedulableBuilder.addTaskSetManager(taskSetManager24, properties2) @@ -217,12 +219,9 @@ 
class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging } test("Nested Pool Test") { - sc = new SparkContext("local", "ClusterSchedulerSuite") - val clusterScheduler = new TaskSchedulerImpl(sc) - var tasks = ArrayBuffer[Task[_]]() - val task = new FakeTask(0) - tasks += task - val taskSet = new TaskSet(tasks.toArray,0,0,0,null) + sc = new SparkContext("local", "TaskSchedulerImplSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + val taskSet = FakeTask.createTaskSet(1) val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1) @@ -240,23 +239,23 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging pool1.addSchedulable(pool10) pool1.addSchedulable(pool11) - val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, clusterScheduler, taskSet) - val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, clusterScheduler, taskSet) + val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, taskScheduler, taskSet) + val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, taskScheduler, taskSet) pool00.addSchedulable(taskSetManager000) pool00.addSchedulable(taskSetManager001) - val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, clusterScheduler, taskSet) - val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, clusterScheduler, taskSet) + val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, taskScheduler, taskSet) + val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, taskScheduler, taskSet) pool01.addSchedulable(taskSetManager010) pool01.addSchedulable(taskSetManager011) - val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, clusterScheduler, taskSet) - val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, clusterScheduler, taskSet) + val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, taskScheduler, taskSet) + val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, taskScheduler, taskSet) pool10.addSchedulable(taskSetManager100) pool10.addSchedulable(taskSetManager101) - val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, clusterScheduler, taskSet) - val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, clusterScheduler, taskSet) + val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, taskScheduler, taskSet) + val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, taskScheduler, taskSet) pool11.addSchedulable(taskSetManager110) pool11.addSchedulable(taskSetManager111) @@ -265,4 +264,35 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging checkTaskSetId(rootPool, 6) checkTaskSetId(rootPool, 2) } + + test("Scheduler does not always schedule tasks on the same workers") { + sc = new SparkContext("local", "TaskSchedulerImplSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + taskScheduler.initialize(new FakeSchedulerBackend) + // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. + var dagScheduler = new DAGScheduler(taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} + override def executorGained(execId: String, host: String) {} + } + + val numFreeCores = 1 + val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores), + new WorkerOffer("executor1", "host1", numFreeCores)) + // Repeatedly try to schedule a 1-task job, and make sure that it doesn't always + // get scheduled on the same executor. 
While there is a chance this test will fail + // because the task randomly gets placed on the first executor all 1000 times, the + // probability of that happening is 2^-1000 (so sufficiently small to be considered + // negligible). + val numTrials = 1000 + val selectedExecutorIds = 1.to(numTrials).map { _ => + val taskSet = FakeTask.createTaskSet(1) + taskScheduler.submitTasks(taskSet) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(1 === taskDescriptions.length) + taskDescriptions(0).executorId + } + var count = selectedExecutorIds.count(_ == workerOffers(0).executorId) + assert(count > 0) + assert(count < numTrials) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 34a7d8cefeea2..33cc7588b919c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.FakeClock -class FakeDAGScheduler(taskScheduler: FakeClusterScheduler) extends DAGScheduler(taskScheduler) { +class FakeDAGScheduler(taskScheduler: FakeTaskScheduler) extends DAGScheduler(taskScheduler) { override def taskStarted(task: Task[_], taskInfo: TaskInfo) { taskScheduler.startedTasks += taskInfo.index } @@ -51,12 +51,12 @@ class FakeDAGScheduler(taskScheduler: FakeClusterScheduler) extends DAGScheduler } /** - * A mock ClusterScheduler implementation that just remembers information about tasks started and + * A mock TaskSchedulerImpl implementation that just remembers information about tasks started and * feedback received from the TaskSetManagers. Note that it's important to initialize this with * a list of "live" executors and their hostnames for isExecutorAlive and hasExecutorsAliveOnHost * to work, and these are required for locality in TaskSetManager. 
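As a quick editorial cross-check of the failure bound cited in the randomized-scheduling test above — assuming only the test's own setup of two equally likely executors and 1000 independent single-task submissions — the chance that every submission lands on one fixed executor can be computed directly (this snippet is illustrative and not part of the patch):

```scala
// Each 1-task round picks the first executor with probability 1/2 when offers are
// shuffled uniformly; 1000 independent rounds all doing so has probability (1/2)^1000.
val pAllOnFirstExecutor = BigDecimal(0.5).pow(1000)
println(pAllOnFirstExecutor)  // ~9.33e-302, negligible as the test comment notes
```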
*/ -class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */) +class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */) extends TaskSchedulerImpl(sc) { val startedTasks = new ArrayBuffer[Long] @@ -87,8 +87,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("TaskSet with no preferences") { sc = new SparkContext("local", "test") - val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) - val taskSet = createTaskSet(1) + val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) + val taskSet = FakeTask.createTaskSet(1) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) // Offer a host with no CPUs @@ -113,8 +113,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("multiple offers with no preferences") { sc = new SparkContext("local", "test") - val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) - val taskSet = createTaskSet(3) + val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) + val taskSet = FakeTask.createTaskSet(3) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) // First three offers should all find tasks @@ -144,8 +144,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("basic delay scheduling") { sc = new SparkContext("local", "test") - val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) - val taskSet = createTaskSet(4, + val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val taskSet = FakeTask.createTaskSet(4, Seq(TaskLocation("host1", "exec1")), Seq(TaskLocation("host2", "exec2")), Seq(TaskLocation("host1"), TaskLocation("host2", "exec2")), @@ -188,9 +188,9 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("delay scheduling with fallback") { sc = new SparkContext("local", "test") - val sched = new FakeClusterScheduler(sc, + val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3")) - val taskSet = createTaskSet(5, + val taskSet = FakeTask.createTaskSet(5, Seq(TaskLocation("host1")), Seq(TaskLocation("host2")), Seq(TaskLocation("host2")), @@ -228,8 +228,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("delay scheduling with failed hosts") { sc = new SparkContext("local", "test") - val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) - val taskSet = createTaskSet(3, + val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val taskSet = FakeTask.createTaskSet(3, Seq(TaskLocation("host1")), Seq(TaskLocation("host2")), Seq(TaskLocation("host3")) @@ -260,8 +260,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("task result lost") { sc = new SparkContext("local", "test") - val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) - val taskSet = createTaskSet(1) + val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) + val taskSet = FakeTask.createTaskSet(1) val clock = new FakeClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) @@ -277,8 +277,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("repeated failures lead to task set abortion") { sc = new SparkContext("local", "test") - val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) - val taskSet = createTaskSet(1) + val sched = new 
FakeTaskScheduler(sc, ("exec1", "host1")) + val taskSet = FakeTask.createTaskSet(1) val clock = new FakeClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) @@ -298,21 +298,6 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { } } - - /** - * Utility method to create a TaskSet, potentially setting a particular sequence of preferred - * locations for each task (given as varargs) if this sequence is not empty. - */ - def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { - if (prefLocs.size != 0 && prefLocs.size != numTasks) { - throw new IllegalArgumentException("Wrong number of task locations") - } - val tasks = Array.tabulate[Task[_]](numTasks) { i => - new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil) - } - new TaskSet(tasks, 0, 0, 0, null) - } - def createTaskResult(id: Int): DirectTaskResult[Int] = { val valueSer = SparkEnv.get.serializer.newInstance() new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 9f011d9c8d132..1036b9f34e9dd 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -28,7 +28,7 @@ import org.scalatest.concurrent.Timeouts._ import org.scalatest.matchers.ShouldMatchers._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SecurityManager, SparkConf, SparkContext} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} @@ -39,6 +39,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT var actorSystem: ActorSystem = null var master: BlockManagerMaster = null var oldArch: String = null + conf.set("spark.authenticate", "false") + val securityMgr = new SecurityManager(conf) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test conf.set("spark.kryoserializer.buffer.mb", "1") @@ -49,7 +51,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId) before { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, conf = conf) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, conf = conf, + securityManager = securityMgr) this.actorSystem = actorSystem conf.set("spark.driver.port", boundPort.toString) @@ -125,7 +128,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("master + 1 manager interaction") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -155,8 +158,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("master + 2 managers interaction") { - store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf) - store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer(conf), 2000, conf) + store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, securityMgr) + store2 = new 
BlockManager("exec2", actorSystem, master, new KryoSerializer(conf), 2000, conf, + securityMgr) val peers = master.getPeers(store.blockManagerId, 1) assert(peers.size === 1, "master did not return the other manager as a peer") @@ -171,7 +175,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("removing block") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -219,7 +223,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("removing rdd") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -253,7 +257,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("reregistration on heart beat") { val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager("", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) val a1 = new Array[Byte](400) store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) @@ -269,7 +273,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("reregistration on block update") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) @@ -288,7 +292,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("reregistration doesn't dead lock") { val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager("", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = List(new Array[Byte](400)) @@ -325,7 +329,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -344,7 +348,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU storage with serialization") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -363,7 +367,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU for partitions of same RDD") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -382,7 +386,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU for 
partitions of multiple RDDs") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 2), new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(1, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) @@ -405,7 +409,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("on-disk storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -418,7 +422,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -433,7 +437,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with getLocalBytes") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -448,7 +452,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with serialization") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -463,7 +467,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with serialization and getLocalBytes") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -478,7 +482,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("LRU with mixed storage levels") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -503,7 +507,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU with streams") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -527,7 +531,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("LRU with mixed storage levels and streams") { - store = new BlockManager("", actorSystem, master, serializer, 
1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -573,7 +577,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("overly large block") { - store = new BlockManager("", actorSystem, master, serializer, 500, conf) + store = new BlockManager("", actorSystem, master, serializer, 500, conf, securityMgr) store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) assert(store.getSingle("a1") === None, "a1 was in store") store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK) @@ -584,7 +588,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("block compression") { try { conf.set("spark.shuffle.compress", "true") - store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) <= 100, "shuffle_0_0_0 was not compressed") @@ -592,7 +596,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store = null conf.set("spark.shuffle.compress", "false") - store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 1000, "shuffle_0_0_0 was compressed") @@ -600,7 +604,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store = null conf.set("spark.broadcast.compress", "true") - store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 100, "broadcast_0 was not compressed") @@ -608,28 +612,28 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store = null conf.set("spark.broadcast.compress", "false") - store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 1000, "broadcast_0 was compressed") store.stop() store = null conf.set("spark.rdd.compress", "true") - store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(rdd(0, 0)) <= 100, "rdd_0_0 was not compressed") store.stop() store = null conf.set("spark.rdd.compress", "false") - store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf, securityMgr) 
store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(rdd(0, 0)) >= 1000, "rdd_0_0 was compressed") store.stop() store = null // Check that any other block types are also kept uncompressed - store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed") store.stop() @@ -643,7 +647,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("block store put failure") { // Use Java serializer so we can create an unserializable error. - store = new BlockManager("", actorSystem, master, new JavaSerializer(conf), 1200, conf) + store = new BlockManager("", actorSystem, master, new JavaSerializer(conf), 1200, conf, + securityMgr) // The put should fail since a1 is not serializable. class UnserializableClass @@ -657,4 +662,18 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(store.getSingle("a1") == None, "a1 should not be in store") } } + + test("SPARK-1194 regression: fix the same-RDD rule for cache replacement") { + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store.putSingle(rdd(0, 0), new Array[Byte](400), StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(1, 0), new Array[Byte](400), StorageLevel.MEMORY_ONLY) + // Access rdd_1_0 to ensure it's not least recently used. + assert(store.getSingle(rdd(1, 0)).isDefined, "rdd_1_0 was not in store") + // According to the same-RDD rule, rdd_1_0 should be replaced here. + store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) + // rdd_1_0 should have been replaced, even it's not least recently used. + assert(store.memoryStore.contains(rdd(0, 0)), "rdd_0_0 was not in store") + assert(store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was not in store") + assert(!store.memoryStore.contains(rdd(1, 0)), "rdd_1_0 was in store") + } } diff --git a/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala new file mode 100644 index 0000000000000..bcf138b5ee6d0 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.spark.storage
+
+import org.scalatest.FunSuite
+import org.apache.spark.{SharedSparkContext, SparkConf, LocalSparkContext, SparkContext}
+
+
+class FlatmapIteratorSuite extends FunSuite with LocalSparkContext {
+  /* Tests the ability of Spark to deal with user-provided iterators from flatMap
+   * calls, which may generate more data than available memory. With any
+   * memory-based persistence Spark will unroll the iterator into an ArrayBuffer
+   * for caching; however, when the user requests DISK_ONLY persistence,
+   * the iterator will be fed directly to the serializer and written to disk.
+   *
+   * This also tests the ObjectOutputStream reset rate. When serializing using the
+   * Java serialization system, the serializer caches objects to prevent writing redundant
+   * data, but that prevents those objects from being garbage collected. Calling 'reset'
+   * flushes that information from the serializer, allowing old objects to be GC'd.
+   */
+  test("Flatmap Iterator to Disk") {
+    val sconf = new SparkConf().setMaster("local").setAppName("iterator_to_disk_test")
+    sc = new SparkContext(sconf)
+    val expand_size = 100
+    val data = sc.parallelize((1 to 5).toSeq).
+      flatMap( x => Stream.range(0, expand_size))
+    var persisted = data.persist(StorageLevel.DISK_ONLY)
+    assert(persisted.count()===500)
+    assert(persisted.filter(_==1).count()===5)
+  }
+
+  test("Flatmap Iterator to Memory") {
+    val sconf = new SparkConf().setMaster("local").setAppName("iterator_to_disk_test")
+    sc = new SparkContext(sconf)
+    val expand_size = 100
+    val data = sc.parallelize((1 to 5).toSeq).
+      flatMap(x => Stream.range(0, expand_size))
+    var persisted = data.persist(StorageLevel.MEMORY_ONLY)
+    assert(persisted.count()===500)
+    assert(persisted.filter(_==1).count()===5)
+  }
+
+  test("Serializer Reset") {
+    val sconf = new SparkConf().setMaster("local").setAppName("serializer_reset_test")
+      .set("spark.serializer.objectStreamReset", "10")
+    sc = new SparkContext(sconf)
+    val expand_size = 500
+    val data = sc.parallelize(Seq(1,2)).
+      flatMap(x => Stream.range(1, expand_size).
+ map(y => "%d: string test %d".format(y,x))) + var persisted = data.persist(StorageLevel.MEMORY_ONLY_SER) + assert(persisted.filter(_.startsWith("1:")).count()===2) + } + +} diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 20ebb1897e6ba..30415814adbba 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -24,6 +24,8 @@ import scala.util.{Failure, Success, Try} import org.eclipse.jetty.server.Server import org.scalatest.FunSuite +import org.apache.spark.SparkConf + class UISuite extends FunSuite { test("jetty port increases under contention") { val startPort = 4040 @@ -34,15 +36,17 @@ class UISuite extends FunSuite { case Failure(e) => // Either case server port is busy hence setup for test complete } - val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq()) - val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq()) + val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq(), + new SparkConf) + val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq(), + new SparkConf) // Allow some wiggle room in case ports on the machine are under contention assert(boundPort1 > startPort && boundPort1 < startPort + 10) assert(boundPort2 > boundPort1 && boundPort2 < boundPort1 + 10) } test("jetty binds to port 0 correctly") { - val (jettyServer, boundPort) = JettyUtils.startJettyServer("0.0.0.0", 0, Seq()) + val (jettyServer, boundPort) = JettyUtils.startJettyServer("0.0.0.0", 0, Seq(), new SparkConf) assert(jettyServer.getState === "STARTED") assert(boundPort != 0) Try {new ServerSocket(boundPort)} match { diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala index c51d12bfe0bc6..757476efdb789 100644 --- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala @@ -72,4 +72,8 @@ class XORShiftRandomSuite extends FunSuite with ShouldMatchers { } + test ("XORShift with zero seed") { + val random = new XORShiftRandom(0L) + assert(random.nextInt() != 0) + } } diff --git a/dev/audit-release/README.md b/dev/audit-release/README.md new file mode 100644 index 0000000000000..2437a98672177 --- /dev/null +++ b/dev/audit-release/README.md @@ -0,0 +1,11 @@ +# Test Application Builds +This directory includes test applications which are built when auditing releases. You can +run them locally by setting appropriate environment variables. 
+ +``` +$ cd sbt_app_core +$ SCALA_VERSION=2.10.3 \ + SPARK_VERSION=1.0.0-SNAPSHOT \ + SPARK_RELEASE_REPOSITORY=file:///home/patrick/.ivy2/local \ + sbt run +``` diff --git a/dev/audit-release/audit_release.py b/dev/audit-release/audit_release.py index 4408658f5e33f..52c367d9b030d 100755 --- a/dev/audit-release/audit_release.py +++ b/dev/audit-release/audit_release.py @@ -31,10 +31,10 @@ import urllib2 ## Fill in release details here: -RELEASE_URL = "http://people.apache.org/~pwendell/spark-0.9.0-incubating-rc5/" +RELEASE_URL = "http://people.apache.org/~pwendell/spark-1.0.0-rc1/" RELEASE_KEY = "9E4FE3AF" RELEASE_REPOSITORY = "https://repository.apache.org/content/repositories/orgapachespark-1006/" -RELEASE_VERSION = "0.9.0-incubating" +RELEASE_VERSION = "1.0.0" SCALA_VERSION = "2.10.3" SCALA_BINARY_VERSION = "2.10" ## @@ -191,10 +191,6 @@ def ensure_path_not_present(x): test("NOTICE" in base_files, "Tarball contains NOTICE file") test("LICENSE" in base_files, "Tarball contains LICENSE file") - os.chdir(os.path.join(WORK_DIR, dir_name)) - readme = "".join(open("README.md").readlines()) - disclaimer_part = "is an effort undergoing incubation" - test(disclaimer_part in readme, "README file contains disclaimer") os.chdir(WORK_DIR) for artifact in artifacts: diff --git a/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala index d49de8b73a856..53fe43215e40e 100644 --- a/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala +++ b/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala @@ -17,6 +17,8 @@ package main.scala +import scala.util.Try + import org.apache.spark.SparkContext import org.apache.spark.SparkContext._ @@ -31,6 +33,17 @@ object SimpleApp { println("Failed to parse log files with Spark") System.exit(-1) } - println("Test succeeded") + + // Regression test for SPARK-1167: Remove metrics-ganglia from default build due to LGPL issue + val foundConsole = Try(Class.forName("org.apache.spark.metrics.sink.ConsoleSink")).isSuccess + val foundGanglia = Try(Class.forName("org.apache.spark.metrics.sink.GangliaSink")).isSuccess + if (!foundConsole) { + println("Console sink not loaded via spark-core") + System.exit(-1) + } + if (foundGanglia) { + println("Ganglia sink was loaded via spark-core") + System.exit(-1) + } } } diff --git a/dev/audit-release/sbt_app_ganglia/build.sbt b/dev/audit-release/sbt_app_ganglia/build.sbt new file mode 100644 index 0000000000000..55db675c722d1 --- /dev/null +++ b/dev/audit-release/sbt_app_ganglia/build.sbt @@ -0,0 +1,31 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +name := "Ganglia Test" + +version := "1.0" + +scalaVersion := System.getenv.get("SCALA_VERSION") + +libraryDependencies += "org.apache.spark" %% "spark-core" % System.getenv.get("SPARK_VERSION") + +libraryDependencies += "org.apache.spark" %% "spark-ganglia-lgpl" % System.getenv.get("SPARK_VERSION") + +resolvers ++= Seq( + "Spark Release Repository" at System.getenv.get("SPARK_RELEASE_REPOSITORY"), + "Akka Repository" at "http://repo.akka.io/releases/", + "Spray Repository" at "http://repo.spray.cc/") diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.scala b/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala similarity index 53% rename from core/src/main/scala/org/apache/spark/api/java/function/PairFunction.scala rename to dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala index d0ba0b6307ee9..0be8e64fbfabd 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.scala +++ b/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala @@ -15,19 +15,25 @@ * limitations under the License. */ -package org.apache.spark.api.java.function +package main.scala -import scala.reflect.ClassTag -import org.apache.spark.api.java.JavaSparkContext +import scala.util.Try -/** - * A function that returns key-value pairs (Tuple2), and can be used to construct PairRDDs. - */ -// PairFunction does not extend Function because some UDF functions, like map, -// are overloaded for both Function and PairFunction. -abstract class PairFunction[T, K, V] extends WrappedFunction1[T, (K, V)] with Serializable { - - def keyType(): ClassTag[K] = JavaSparkContext.fakeClassTag +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ - def valueType(): ClassTag[V] = JavaSparkContext.fakeClassTag +object SimpleApp { + def main(args: Array[String]) { + // Regression test for SPARK-1167: Remove metrics-ganglia from default build due to LGPL issue + val foundConsole = Try(Class.forName("org.apache.spark.metrics.sink.ConsoleSink")).isSuccess + val foundGanglia = Try(Class.forName("org.apache.spark.metrics.sink.GangliaSink")).isSuccess + if (!foundConsole) { + println("Console sink not loaded via spark-core") + System.exit(-1) + } + if (!foundGanglia) { + println("Ganglia sink not loaded via spark-ganglia-lgpl") + System.exit(-1) + } + } } diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 7cebace5069f8..995106f111443 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -39,8 +39,8 @@ GIT_TAG=v$RELEASE_VERSION # Artifact publishing -git clone https://git-wip-us.apache.org/repos/asf/incubator-spark.git -b $GIT_BRANCH -cd incubator-spark +git clone https://git-wip-us.apache.org/repos/asf/spark.git -b $GIT_BRANCH +cd spark export MAVEN_OPTS="-Xmx3g -XX:MaxPermSize=1g -XX:ReservedCodeCacheSize=1g" mvn -Pyarn release:clean @@ -49,21 +49,21 @@ mvn -DskipTests \ -Darguments="-DskipTests=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \ -Dusername=$GIT_USERNAME -Dpassword=$GIT_PASSWORD \ -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ - -Pyarn \ + -Pyarn -Pspark-ganglia-lgpl \ -Dtag=$GIT_TAG -DautoVersionSubmodules=true \ --batch-mode release:prepare mvn -DskipTests \ -Darguments="-DskipTests=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \ -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ - -Pyarn \ + -Pyarn -Pspark-ganglia-lgpl\ release:perform -rm -rf incubator-spark +rm 
-rf spark # Source and binary tarballs -git clone https://git-wip-us.apache.org/repos/asf/incubator-spark.git -cd incubator-spark +git clone https://git-wip-us.apache.org/repos/asf/spark.git +cd spark git checkout --force $GIT_TAG release_hash=`git rev-parse HEAD` @@ -71,7 +71,7 @@ rm .gitignore rm -rf .git cd .. -cp -r incubator-spark spark-$RELEASE_VERSION +cp -r spark spark-$RELEASE_VERSION tar cvzf spark-$RELEASE_VERSION.tgz spark-$RELEASE_VERSION echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour --output spark-$RELEASE_VERSION.tgz.asc \ --detach-sig spark-$RELEASE_VERSION.tgz @@ -85,7 +85,7 @@ make_binary_release() { NAME=$1 MAVEN_FLAGS=$2 - cp -r incubator-spark spark-$RELEASE_VERSION-bin-$NAME + cp -r spark spark-$RELEASE_VERSION-bin-$NAME cd spark-$RELEASE_VERSION-bin-$NAME export MAVEN_OPTS="-Xmx3g -XX:MaxPermSize=1g -XX:ReservedCodeCacheSize=1g" mvn $MAVEN_FLAGS -DskipTests clean package @@ -118,9 +118,9 @@ scp spark* \ $USER_NAME@people.apache.org:/home/$USER_NAME/public_html/$rc_folder/ # Docs -cd incubator-spark +cd spark cd docs -jekyll build +PRODUCTION=1 jekyll build echo "Copying release documentation" rc_docs_folder=${rc_folder}-docs rsync -r _site/* $USER_NAME@people.apache.org /home/$USER_NAME/public_html/$rc_docs_folder diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 93621c96daf2d..e8f78fc5f231a 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -38,7 +38,7 @@ # Remote name which points to Apache git PUSH_REMOTE_NAME = os.environ.get("PUSH_REMOTE_NAME", "apache") -GIT_API_BASE = "https://api.github.com/repos/apache/incubator-spark" +GIT_API_BASE = "https://api.github.com/repos/apache/spark" # Prefix added to temporary branches BRANCH_PREFIX = "PR_TOOL" diff --git a/dev/run-tests b/dev/run-tests index d65a397b4c8c7..cf0b940c09a81 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -27,6 +27,16 @@ rm -rf ./work # Fail fast set -e +if test -x "$JAVA_HOME/bin/java"; then + declare java_cmd="$JAVA_HOME/bin/java" +else + declare java_cmd=java +fi + +JAVA_VERSION=$($java_cmd -version 2>&1 | sed 's/java version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q') +[ "$JAVA_VERSION" -ge 18 ] && echo "" || echo "[Warn] Java 8 tests will not run, because JDK version is < 1.8." + + echo "=========================================================================" echo "Running Scala style checks" echo "=========================================================================" diff --git a/docker/README.md b/docker/README.md index bf59e77d111f9..40ba9c3065946 100644 --- a/docker/README.md +++ b/docker/README.md @@ -2,4 +2,6 @@ Spark docker files =========== Drawn from Matt Massie's docker files (https://github.com/massie/dockerfiles), -as well as some updates from Andre Schumacher (https://github.com/AndreSchumacher/docker). \ No newline at end of file +as well as some updates from Andre Schumacher (https://github.com/AndreSchumacher/docker). + +Tested with Docker version 0.8.1. 
diff --git a/docker/spark-test/master/default_cmd b/docker/spark-test/master/default_cmd index a5b1303c2ebdb..5a7da3446f6d2 100755 --- a/docker/spark-test/master/default_cmd +++ b/docker/spark-test/master/default_cmd @@ -19,4 +19,10 @@ IP=$(ip -o -4 addr list eth0 | perl -n -e 'if (m{inet\s([\d\.]+)\/\d+\s}xms) { print $1 }') echo "CONTAINER_IP=$IP" -/opt/spark/spark-class org.apache.spark.deploy.master.Master -i $IP +export SPARK_LOCAL_IP=$IP +export SPARK_PUBLIC_DNS=$IP + +# Avoid the default Docker behavior of mapping our IP address to an unreachable host name +umount /etc/hosts + +/opt/spark/bin/spark-class org.apache.spark.deploy.master.Master -i $IP diff --git a/docker/spark-test/worker/default_cmd b/docker/spark-test/worker/default_cmd index ab6336f70c1c6..31b06cb0eb047 100755 --- a/docker/spark-test/worker/default_cmd +++ b/docker/spark-test/worker/default_cmd @@ -19,4 +19,10 @@ IP=$(ip -o -4 addr list eth0 | perl -n -e 'if (m{inet\s([\d\.]+)\/\d+\s}xms) { print $1 }') echo "CONTAINER_IP=$IP" -/opt/spark/spark-class org.apache.spark.deploy.worker.Worker $1 +export SPARK_LOCAL_IP=$IP +export SPARK_PUBLIC_DNS=$IP + +# Avoid the default Docker behavior of mapping our IP address to an unreachable host name +umount /etc/hosts + +/opt/spark/bin/spark-class org.apache.spark.deploy.worker.Worker $1 diff --git a/docs/README.md b/docs/README.md index cc09d6e88f41e..0678fc5c86706 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,6 +1,6 @@ Welcome to the Spark documentation! -This readme will walk you through navigating and building the Spark documentation, which is included here with the Spark source code. You can also find documentation specific to release versions of Spark at http://spark.incubator.apache.org/documentation.html. +This readme will walk you through navigating and building the Spark documentation, which is included here with the Spark source code. You can also find documentation specific to release versions of Spark at http://spark.apache.org/documentation.html. Read on to learn more about viewing documentation in plain text (i.e., markdown) or building the documentation yourself. Why build it yourself? So that you have the docs that corresponds to whichever version of Spark you currently have checked out of revision control. @@ -10,9 +10,22 @@ We include the Spark documentation as part of the source (as opposed to using a In this directory you will find textfiles formatted using Markdown, with an ".md" suffix. You can read those text files directly if you want. Start with index.md. -To make things quite a bit prettier and make the links easier to follow, generate the html version of the documentation based on the src directory by running `jekyll build` in the docs directory. Use the command `SKIP_SCALADOC=1 jekyll build` to skip building and copying over the scaladoc which can be timely. To use the `jekyll` command, you will need to have Jekyll installed, the easiest way to do this is via a Ruby Gem, see the [jekyll installation instructions](http://jekyllrb.com/docs/installation). This will create a directory called _site containing index.html as well as the rest of the compiled files. Read more about Jekyll at https://github.com/mojombo/jekyll/wiki. - -In addition to generating the site as html from the markdown files, jekyll can serve up the site via a webserver. To build and run a local webserver use the command `jekyll serve` (or the faster variant `SKIP_SCALADOC=1 jekyll serve`), which runs the webserver on port 4000, then visit the site at http://localhost:4000. 
+The markdown code can be compiled to HTML using the +[Jekyll tool](http://jekyllrb.com). +To use the `jekyll` command, you will need to have Jekyll installed. +The easiest way to do this is via a Ruby Gem, see the +[jekyll installation instructions](http://jekyllrb.com/docs/installation). +Compiling the site with Jekyll will create a directory called +_site containing index.html as well as the rest of the compiled files. + +You can modify the default Jekyll build as follows: + + # Skip generating API docs (which takes a while) + $ SKIP_SCALADOC=1 jekyll build + # Serve content locally on port 4000 + $ jekyll serve --watch + # Build the site with extra features used on the live page + $ PRODUCTION=1 jekyll build ## Pygments diff --git a/docs/_config.yml b/docs/_config.yml index 9e5a95fe53af6..aa5a5adbc1743 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -3,10 +3,10 @@ markdown: kramdown # These allow the documentation to be updated with nerw releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 1.0.0-incubating-SNAPSHOT +SPARK_VERSION: 1.0.0-SNAPSHOT SPARK_VERSION_SHORT: 1.0.0 SCALA_BINARY_VERSION: "2.10" SCALA_VERSION: "2.10.3" MESOS_VERSION: 0.13.0 SPARK_ISSUE_TRACKER_URL: https://spark-project.atlassian.net -SPARK_GITHUB_URL: https://github.com/apache/incubator-spark +SPARK_GITHUB_URL: https://github.com/apache/spark diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 7114e1f5dd5b9..49fd78ca98655 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -24,9 +24,9 @@ + {% production %} + {% endproduction %} @@ -159,16 +159,6 @@

    Heading
    -->
-    Apache Spark is an effort undergoing incubation at the Apache Software Foundation.
    - diff --git a/docs/_plugins/production_tag.rb b/docs/_plugins/production_tag.rb new file mode 100644 index 0000000000000..9f870cf2137af --- /dev/null +++ b/docs/_plugins/production_tag.rb @@ -0,0 +1,14 @@ +module Jekyll + class ProductionTag < Liquid::Block + + def initialize(tag_name, markup, tokens) + super + end + + def render(context) + if ENV['PRODUCTION'] then super else "" end + end + end +end + +Liquid::Template.register_tag('production', Jekyll::ProductionTag) diff --git a/docs/bagel-programming-guide.md b/docs/bagel-programming-guide.md index b070d8e73a38b..da6d0c9dcd97b 100644 --- a/docs/bagel-programming-guide.md +++ b/docs/bagel-programming-guide.md @@ -108,7 +108,7 @@ _Example_ ## Operations -Here are the actions and types in the Bagel API. See [Bagel.scala](https://github.com/apache/incubator-spark/blob/master/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala) for details. +Here are the actions and types in the Bagel API. See [Bagel.scala](https://github.com/apache/spark/blob/master/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala) for details. ### Actions diff --git a/docs/building-with-maven.md b/docs/building-with-maven.md index ded12926885b9..730a6e7932564 100644 --- a/docs/building-with-maven.md +++ b/docs/building-with-maven.md @@ -25,6 +25,8 @@ If you don't run this, you may see errors like the following: You can fix this by setting the `MAVEN_OPTS` variable as discussed before. +*Note: For Java 1.8 and above this step is not required.* + ## Specifying the Hadoop version ## Because HDFS is not protocol-compatible across versions, if you want to read from HDFS, you'll need to build Spark against the specific HDFS version in your environment. You can do this through the "hadoop.version" property. If unset, Spark will build against Hadoop 1.0.4 by default. @@ -54,7 +56,7 @@ Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.o The ScalaTest plugin also supports running only a specific test suite as follows: - $ mvn -Dhadoop.version=... -Dsuites=spark.repl.ReplSuite test + $ mvn -Dhadoop.version=... -Dsuites=org.apache.spark.repl.ReplSuite test ## Continuous Compilation ## @@ -76,3 +78,19 @@ The maven build includes support for building a Debian package containing the as $ mvn -Pdeb -DskipTests clean package The debian package can then be found under assembly/target. We added the short commit hash to the file name so that we can distinguish individual packages built for SNAPSHOT versions. + +## Running java 8 test suites. + +Running only java 8 tests and nothing else. + + $ mvn install -DskipTests -Pjava8-tests + +Java 8 tests are run when -Pjava8-tests profile is enabled, they will run in spite of -DskipTests. +For these tests to run your system must have a JDK 8 installation. +If you have JDK 8 installed but it is not the system default, you can set JAVA_HOME to point to JDK 8 before running the tests. + +## Packaging without Hadoop dependencies for deployment on YARN ## + +The assembly jar produced by "mvn package" will, by default, include all of Spark's dependencies, including Hadoop and some of its ecosystem projects. On YARN deployments, this causes multiple versions of these to appear on executor classpaths: the version packaged in the Spark assembly and the version on each node, included with yarn.application.classpath. The "hadoop-provided" profile builds the assembly without including Hadoop-ecosystem projects, like ZooKeeper and Hadoop itself. 
+ + diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index e16703292cc22..a555a7b5023e3 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -13,7 +13,7 @@ object in your main program (called the _driver program_). Specifically, to run on a cluster, the SparkContext can connect to several types of _cluster managers_ (either Spark's own standalone cluster manager or Mesos/YARN), which allocate resources across applications. Once connected, Spark acquires *executors* on nodes in the cluster, which are -worker processes that run computations and store data for your application. +processes that run computations and store data for your application. Next, it sends your application code (defined by JAR or Python files passed to SparkContext) to the executors. Finally, SparkContext sends *tasks* for the executors to run. diff --git a/docs/configuration.md b/docs/configuration.md index 8e4c48c81f8be..a006224d5080c 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -147,6 +147,34 @@ Apart from these, the following properties are also available, and may be useful How many stages the Spark UI remembers before garbage collecting. + + spark.ui.filters + None + + Comma separated list of filter class names to apply to the Spark web ui. The filter should be a + standard javax servlet Filter. Parameters to each filter can also be specified by setting a + java system property of spark.<class name of filter>.params='param1=value1,param2=value2' + (e.g.-Dspark.ui.filters=com.test.filter1 -Dspark.com.test.filter1.params='param1=foo,param2=testing') + + + + spark.ui.acls.enable + false + + Whether spark web ui acls should are enabled. If enabled, this checks to see if the user has + access permissions to view the web ui. See spark.ui.view.acls for more details. + Also note this requires the user to be known, if the user comes across as null no checks + are done. Filters can be used to authenticate and set the user. + + + + spark.ui.view.acls + Empty + + Comma separated list of users that have view access to the spark web ui. By default only the + user that started the Spark job has view access. + + spark.shuffle.compress true @@ -201,6 +229,13 @@ Apart from these, the following properties are also available, and may be useful multi-user services. + + spark.scheduler.revive.interval + 1000 + + The interval length for the scheduler to revive the worker resource offers to run tasks. (in milliseconds) + + spark.reducer.maxMbInFlight 48 @@ -237,6 +272,17 @@ Apart from these, the following properties are also available, and may be useful exceeded" exception inside Kryo. Note that there will be one buffer per core on each worker. + + spark.serializer.objectStreamReset + 10000 + + When serializing using org.apache.spark.serializer.JavaSerializer, the serializer caches + objects to prevent writing redundant data, however that stops garbage collection of those + objects. By calling 'reset' you flush that info from the serializer, and allow old + objects to be collected. To turn off this periodic reset set it to a value of <= 0. + By default it will reset the serializer every 10,000 objects. + + spark.broadcast.factory org.apache.spark.broadcast.
    HttpBroadcastFactory @@ -469,7 +515,7 @@ Apart from these, the following properties are also available, and may be useful the whole cluster by default.
    Note: this setting needs to be configured in the standalone cluster master, not in individual applications; you can set it through SPARK_JAVA_OPTS in spark-env.sh. - + spark.files.overwrite @@ -478,6 +524,38 @@ Apart from these, the following properties are also available, and may be useful Whether to overwrite files added through SparkContext.addFile() when the target file exists and its contents do not match those of the source. + + spark.files.fetchTimeout + false + + Communication timeout to use when fetching files added through SparkContext.addFile() from + the driver. + + + + spark.authenticate + false + + Whether spark authenticates its internal connections. See spark.authenticate.secret if not + running on Yarn. + + + + spark.authenticate.secret + None + + Set the secret key used for Spark to authenticate between components. This needs to be set if + not running on Yarn and authentication is enabled. + + + + spark.core.connection.auth.wait.timeout + 30 + + Number of seconds for the connection to wait for authentication to occur before timing + out and giving up. + + ## Viewing Spark Properties diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 3dfed7bea9ea8..1238e3e0a4e7d 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -135,7 +135,7 @@ Like RDDs, property graphs are immutable, distributed, and fault-tolerant. Chan structure of the graph are accomplished by producing a new graph with the desired changes. Note that substantial parts of the original graph (i.e., unaffected structure, attributes, and indicies) are reused in the new graph reducing the cost of this inherently functional data-structure. The -graph is partitioned across the workers using a range of vertex-partitioning heuristics. As with +graph is partitioned across the executors using a range of vertex-partitioning heuristics. As with RDDs, each partition of the graph can be recreated on a different machine in the event of a failure. Logically the property graph corresponds to a pair of typed collections (RDDs) encoding the diff --git a/docs/index.md b/docs/index.md index aa9c8666e7d75..23311101e1712 100644 --- a/docs/index.md +++ b/docs/index.md @@ -9,7 +9,7 @@ It also supports a rich set of higher-level tools including [Shark](http://shark # Downloading -Get Spark by visiting the [downloads page](http://spark.incubator.apache.org/downloads.html) of the Apache Spark site. This documentation is for Spark version {{site.SPARK_VERSION}}. +Get Spark by visiting the [downloads page](http://spark.apache.org/downloads.html) of the Apache Spark site. This documentation is for Spark version {{site.SPARK_VERSION}}. Spark runs on both Windows and UNIX-like systems (e.g. Linux, Mac OS). All you need to run it is to have `java` to installed on your system `PATH`, or the `JAVA_HOME` environment variable pointing to a Java installation. @@ -23,10 +23,12 @@ For its Scala API, Spark {{site.SPARK_VERSION}} depends on Scala {{site.SCALA_BI # Running the Examples and Shell -Spark comes with several sample programs in the `examples` directory. -To run one of the samples, use `./bin/run-example ` in the top-level Spark directory +Spark comes with several sample programs. Scala and Java examples are in the `examples` directory, and Python examples are in `python/examples`. +To run one of the Java or Scala sample programs, use `./bin/run-example ` in the top-level Spark directory (the `bin/run-example` script sets up the appropriate paths and launches that program). 
For example, try `./bin/run-example org.apache.spark.examples.SparkPi local`. +To run a Python sample program, use `./bin/pyspark `. For example, try `./bin/pyspark ./python/examples/pi.py local`. + Each example prints usage help when run with no parameters. Note that all of the sample programs take a `` parameter specifying the cluster URL @@ -96,13 +98,14 @@ For this version of Spark (0.8.1) Hadoop 2.2.x (or newer) users will have to bui * [Amazon EC2](ec2-scripts.html): scripts that let you launch a cluster on EC2 in about 5 minutes * [Standalone Deploy Mode](spark-standalone.html): launch a standalone cluster quickly without a third-party cluster manager * [Mesos](running-on-mesos.html): deploy a private cluster using - [Apache Mesos](http://incubator.apache.org/mesos) + [Apache Mesos](http://mesos.apache.org) * [YARN](running-on-yarn.html): deploy Spark on top of Hadoop NextGen (YARN) **Other documents:** * [Configuration](configuration.html): customize Spark via its configuration system * [Tuning Guide](tuning.html): best practices to optimize performance and memory use +* [Security](security.html): Spark security support * [Hardware Provisioning](hardware-provisioning.html): recommendations for cluster hardware * [Job Scheduling](job-scheduling.html): scheduling resources across and within Spark applications * [Building Spark with Maven](building-with-maven.html): build Spark using the Maven system @@ -110,20 +113,20 @@ For this version of Spark (0.8.1) Hadoop 2.2.x (or newer) users will have to bui **External resources:** -* [Spark Homepage](http://spark.incubator.apache.org) +* [Spark Homepage](http://spark.apache.org) * [Shark](http://shark.cs.berkeley.edu): Apache Hive over Spark -* [Mailing Lists](http://spark.incubator.apache.org/mailing-lists.html): ask questions about Spark here +* [Mailing Lists](http://spark.apache.org/mailing-lists.html): ask questions about Spark here * [AMP Camps](http://ampcamp.berkeley.edu/): a series of training camps at UC Berkeley that featured talks and exercises about Spark, Shark, Mesos, and more. [Videos](http://ampcamp.berkeley.edu/agenda-2012), [slides](http://ampcamp.berkeley.edu/agenda-2012) and [exercises](http://ampcamp.berkeley.edu/exercises-2012) are available online for free. -* [Code Examples](http://spark.incubator.apache.org/examples.html): more are also available in the [examples subfolder](https://github.com/apache/incubator-spark/tree/master/examples/src/main/scala/) of Spark +* [Code Examples](http://spark.apache.org/examples.html): more are also available in the [examples subfolder](https://github.com/apache/spark/tree/master/examples/src/main/scala/) of Spark * [Paper Describing Spark](http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf) * [Paper Describing Spark Streaming](http://www.eecs.berkeley.edu/Pubs/TechRpts/2012/EECS-2012-259.pdf) # Community -To get help using Spark or keep up with Spark development, sign up for the [user mailing list](http://spark.incubator.apache.org/mailing-lists.html). +To get help using Spark or keep up with Spark development, sign up for the [user mailing list](http://spark.apache.org/mailing-lists.html). If you're in the San Francisco Bay Area, there's a regular [Spark meetup](http://www.meetup.com/spark-users/) every few weeks. Come by to meet the developers and other users. 
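The new properties documented in the configuration table above can also be set programmatically on a `SparkConf` before the `SparkContext` is created; a minimal sketch, with purely illustrative values and user names:

```scala
import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .setMaster("local")                              // illustrative master
  .setAppName("configured-app")
  .set("spark.ui.acls.enable", "true")             // check view permissions on the web UI
  .set("spark.ui.view.acls", "alice,bob")          // users allowed to view the UI (example names)
  .set("spark.scheduler.revive.interval", "1000")  // revive resource offers every 1000 ms

val sc = new SparkContext(conf)
```

Equivalent settings can also be passed as Java system properties (for example through `SPARK_JAVA_OPTS`), which is the form the `spark.ui.filters` parameter description above assumes.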
diff --git a/docs/java-programming-guide.md b/docs/java-programming-guide.md index 07732fa1229f3..6632360f6e3ca 100644 --- a/docs/java-programming-guide.md +++ b/docs/java-programming-guide.md @@ -21,15 +21,21 @@ operations (e.g. map) and handling RDDs of different types, as discussed next. There are a few key differences between the Java and Scala APIs: -* Java does not support anonymous or first-class functions, so functions must - be implemented by extending the +* Java does not support anonymous or first-class functions, so functions are passed + using anonymous classes that implement the [`org.apache.spark.api.java.function.Function`](api/core/index.html#org.apache.spark.api.java.function.Function), [`Function2`](api/core/index.html#org.apache.spark.api.java.function.Function2), etc. - classes. + interfaces. * To maintain type safety, the Java API defines specialized Function and RDD classes for key-value pairs and doubles. For example, [`JavaPairRDD`](api/core/index.html#org.apache.spark.api.java.JavaPairRDD) stores key-value pairs. +* Some methods are defined on the basis of the passed anonymous function's + (a.k.a lambda expression) return type, + for example mapToPair(...) or flatMapToPair returns + [`JavaPairRDD`](api/core/index.html#org.apache.spark.api.java.JavaPairRDD), + similarly mapToDouble and flatMapToDouble returns + [`JavaDoubleRDD`](api/core/index.html#org.apache.spark.api.java.JavaDoubleRDD). * RDD methods like `collect()` and `countByKey()` return Java collections types, such as `java.util.List` and `java.util.Map`. * Key-value pairs, which are simply written as `(key, value)` in Scala, are represented @@ -53,10 +59,10 @@ each specialized RDD class, so filtering a `PairRDD` returns a new `PairRDD`, etc (this acheives the "same-result-type" principle used by the [Scala collections framework](http://docs.scala-lang.org/overviews/core/architecture-of-scala-collections.html)). -## Function Classes +## Function Interfaces -The following table lists the function classes used by the Java API. Each -class has a single abstract method, `call()`, that must be implemented. +The following table lists the function interfaces used by the Java API. Each +interface has a single abstract method, `call()`, that must be implemented. @@ -78,7 +84,6 @@ RDD [storage level](scala-programming-guide.html#rdd-persistence) constants, suc declared in the [org.apache.spark.api.java.StorageLevels](api/core/index.html#org.apache.spark.api.java.StorageLevels) class. To define your own storage level, you can use StorageLevels.create(...). - # Other Features The Java API supports other Spark features, including @@ -86,6 +91,21 @@ The Java API supports other Spark features, including [broadcast variables](scala-programming-guide.html#broadcast-variables), and [caching](scala-programming-guide.html#rdd-persistence). +# Upgrading From Pre-1.0 Versions of Spark + +In version 1.0 of Spark the Java API was refactored to better support Java 8 +lambda expressions. Users upgrading from older versions of Spark should note +the following changes: + +* All `org.apache.spark.api.java.function.*` have been changed from abstract + classes to interfaces. This means that concrete implementations of these + `Function` classes will need to use `implements` rather than `extends`. +* Certain transformation functions now have multiple versions depending + on the return type. In Spark core, the map functions (map, flatMap, + mapPartitons) have type-specific versions, e.g. 
+ [`mapToPair`](api/core/index.html#org.apache.spark.api.java.JavaRDD@mapToPair[K2,V2](f:org.apache.spark.api.java.function.PairFunction[T,K2,V2]):org.apache.spark.api.java.JavaPairRDD[K2,V2]) + and [`mapToDouble`](api/core/index.html#org.apache.spark.api.java.JavaRDD@mapToDouble[R](f:org.apache.spark.api.java.function.DoubleFunction[T]):org.apache.spark.api.java.JavaDoubleRDD). + Spark Streaming also uses the same approach, e.g. [`transformToPair`](api/streaming/index.html#org.apache.spark.streaming.api.java.JavaDStream@transformToPair[K2,V2](transformFunc:org.apache.spark.api.java.function.Function[R,org.apache.spark.api.java.JavaPairRDD[K2,V2]]):org.apache.spark.streaming.api.java.JavaPairDStream[K2,V2]). # Example @@ -127,11 +147,20 @@ class Split extends FlatMapFunction { JavaRDD words = lines.flatMap(new Split()); {% endhighlight %} +Java 8+ users can also write the above `FlatMapFunction` in a more concise way using +a lambda expression: + +{% highlight java %} +JavaRDD words = lines.flatMap(s -> Arrays.asList(s.split(" "))); +{% endhighlight %} + +This lambda syntax can be applied to all anonymous classes in Java 8. + Continuing with the word count example, we map each word to a `(word, 1)` pair: {% highlight java %} import scala.Tuple2; -JavaPairRDD ones = words.map( +JavaPairRDD ones = words.mapToPair( new PairFunction() { public Tuple2 call(String s) { return new Tuple2(s, 1); @@ -140,7 +169,7 @@ JavaPairRDD ones = words.map( ); {% endhighlight %} -Note that `map` was passed a `PairFunction` and +Note that `mapToPair` was passed a `PairFunction` and returned a `JavaPairRDD`. To finish the word count program, we will use `reduceByKey` to count the @@ -164,7 +193,7 @@ possible to chain the RDD transformations, so the word count example could also be written as: {% highlight java %} -JavaPairRDD counts = lines.flatMap( +JavaPairRDD counts = lines.flatMapToPair( ... ).map( ... @@ -180,16 +209,17 @@ just a matter of style. We currently provide documentation for the Java API as Scaladoc, in the [`org.apache.spark.api.java` package](api/core/index.html#org.apache.spark.api.java.package), because -some of the classes are implemented in Scala. The main downside is that the types and function +some of the classes are implemented in Scala. It is important to note that the types and function definitions show Scala syntax (for example, `def reduce(func: Function2[T, T]): T` instead of -`T reduce(Function2 func)`). -We hope to generate documentation with Java-style syntax in the future. +`T reduce(Function2 func)`). In addition, the Scala `trait` modifier is used for Java +interface classes. We hope to generate documentation with Java-style syntax in the future to +avoid these quirks. # Where to Go from Here Spark includes several sample programs using the Java API in -[`examples/src/main/java`](https://github.com/apache/incubator-spark/tree/master/examples/src/main/java/org/apache/spark/examples). You can run them by passing the class name to the +[`examples/src/main/java`](https://github.com/apache/spark/tree/master/examples/src/main/java/org/apache/spark/examples). 
You can run them by passing the class name to the `bin/run-example` script included in Spark; for example: ./bin/run-example org.apache.spark.examples.JavaWordCount diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index df2faa5e41b18..94604f301dd46 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -39,8 +39,8 @@ Resource allocation can be configured as follows, based on the cluster type: * **Mesos:** To use static partitioning on Mesos, set the `spark.mesos.coarse` configuration property to `true`, and optionally set `spark.cores.max` to limit each application's resource share as in the standalone mode. You should also set `spark.executor.memory` to control the executor memory. -* **YARN:** The `--num-workers` option to the Spark YARN client controls how many workers it will allocate - on the cluster, while `--worker-memory` and `--worker-cores` control the resources per worker. +* **YARN:** The `--num-executors` option to the Spark YARN client controls how many executors it will allocate + on the cluster, while `--executor-memory` and `--executor-cores` control the resources per executor. A second option available on Mesos is _dynamic sharing_ of CPU cores. In this mode, each Spark application still has a fixed and independent memory allocation (set by `spark.executor.memory`), but when the diff --git a/docs/js/main.js b/docs/js/main.js index 102699789a71a..0bd2286cced19 100755 --- a/docs/js/main.js +++ b/docs/js/main.js @@ -1,26 +1,3 @@ - -// From docs.scala-lang.org -function styleCode() { - if (typeof disableStyleCode != "undefined") { - return; - } - $(".codetabs pre code").parent().each(function() { - if (!$(this).hasClass("prettyprint")) { - var lang = $(this).parent().data("lang"); - if (lang == "python") { - lang = "py" - } - if (lang == "bash") { - lang = "bsh" - } - $(this).addClass("prettyprint lang-"+lang+" linenums"); - } - }); - console.log("runningPrettyPrint()") - prettyPrint(); -} - - function codeTabs() { var counter = 0; var langImages = { @@ -97,11 +74,7 @@ function viewSolution() { } -$(document).ready(function() { +$(function() { codeTabs(); viewSolution(); - $('#chapter-toc').toc({exclude: '', context: '.container'}); - $('#chapter-toc').prepend('

    In This Chapter

    '); - makeCollapsable($('#global-toc'), "", "global-toc", "Show Table of Contents"); - //styleCode(); }); diff --git a/docs/mllib-classification-regression.md b/docs/mllib-classification-regression.md index 18a3e8e075086..d5bd8042ca2ec 100644 --- a/docs/mllib-classification-regression.md +++ b/docs/mllib-classification-regression.md @@ -77,8 +77,8 @@ between the two goals of small loss and small model complexity. **Distributed Datasets.** For all currently implemented optimization methods for classification, the data must be -distributed between the worker machines *by examples*. Every machine holds a consecutive block of -the `$n$` example/label pairs `$(\x_i,y_i)$`. +distributed between processes on the worker machines *by examples*. Machines hold consecutive +blocks of the `$n$` example/label pairs `$(\x_i,y_i)$`. In other words, the input distributed dataset ([RDD](scala-programming-guide.html#resilient-distributed-datasets-rdds)) must be the set of vectors `$\x_i\in\R^d$`. diff --git a/docs/monitoring.md b/docs/monitoring.md index e9b1d2b2f4ffb..15bfb041780da 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -48,11 +48,22 @@ Each instance can report to zero or more _sinks_. Sinks are contained in the * `ConsoleSink`: Logs metrics information to the console. * `CSVSink`: Exports metrics data to CSV files at regular intervals. -* `GangliaSink`: Sends metrics to a Ganglia node or multicast group. * `JmxSink`: Registers metrics for viewing in a JXM console. * `MetricsServlet`: Adds a servlet within the existing Spark UI to serve metrics data as JSON data. * `GraphiteSink`: Sends metrics to a Graphite node. +Spark also supports a Ganglia sink which is not included in the default build due to +licensing restrictions: + +* `GangliaSink`: Sends metrics to a Ganglia node or multicast group. + +To install the `GangliaSink` you'll need to perform a custom build of Spark. _**Note that +by embedding this library you will include [LGPL](http://www.gnu.org/copyleft/lesser.html)-licensed +code in your Spark package**_. For sbt users, set the +`SPARK_GANGLIA_LGPL` environment variable before building. For Maven users, enable +the `-Pspark-ganglia-lgpl` profile. In addition to modifying the cluster's Spark build +user applications will need to link to the `spark-ganglia-lgpl` artifact. + The syntax of the metrics configuration file is defined in an example configuration file, `$SPARK_HOME/conf/metrics.properties.template`. diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md index 7c5283fb0b6fb..cbe7d820b455e 100644 --- a/docs/python-programming-guide.md +++ b/docs/python-programming-guide.md @@ -43,9 +43,9 @@ def is_error(line): errors = logData.filter(is_error) {% endhighlight %} -PySpark will automatically ship these functions to workers, along with any objects that they reference. -Instances of classes will be serialized and shipped to workers by PySpark, but classes themselves cannot be automatically distributed to workers. -The [Standalone Use](#standalone-use) section describes how to ship code dependencies to workers. +PySpark will automatically ship these functions to executors, along with any objects that they reference. +Instances of classes will be serialized and shipped to executors by PySpark, but classes themselves cannot be automatically distributed to executors. +The [Standalone Use](#standalone-use) section describes how to ship code dependencies to executors. 
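As the monitoring guide above explains, `GangliaSink` now lives in the separate `spark-ganglia-lgpl` module, so a user application that wants it must link against that artifact explicitly. A minimal sbt fragment mirroring the audit application's `build.sbt` earlier in this patch (the version string is an assumption; match it to your Spark deployment):

```scala
// build.sbt
libraryDependencies += "org.apache.spark" %% "spark-core" % "1.0.0"

// Pulls in org.apache.spark.metrics.sink.GangliaSink (note: LGPL-licensed code).
libraryDependencies += "org.apache.spark" %% "spark-ganglia-lgpl" % "1.0.0"
```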
In addition, PySpark fully supports interactive use---simply run `./bin/pyspark` to launch an interactive shell. @@ -157,7 +157,7 @@ some example applications. # Where to Go from Here -PySpark also includes several sample programs in the [`python/examples` folder](https://github.com/apache/incubator-spark/tree/master/python/examples). +PySpark also includes several sample programs in the [`python/examples` folder](https://github.com/apache/spark/tree/master/python/examples). You can run them by passing the files to `pyspark`; e.g.: ./bin/pyspark python/examples/wordcount.py diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index cd4509ede735a..2e9dec4856ee9 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -29,7 +29,7 @@ If you want to test out the YARN deployment mode, you can use the current Spark # Configuration -Most of the configs are the same for Spark on YARN as other deploys. See the Configuration page for more information on those. These are configs that are specific to SPARK on YARN. +Most of the configs are the same for Spark on YARN as for other deployment modes. See the Configuration page for more information on those. These are configs that are specific to Spark on YARN. Environment variables: @@ -41,28 +41,29 @@ System Properties: * `spark.yarn.submit.file.replication`, the HDFS replication level for the files uploaded into HDFS for the application. These include things like the spark jar, the app jar, and any distributed cache files/archives. * `spark.yarn.preserve.staging.files`, set to true to preserve the staged files(spark jar, app jar, distributed cache files) at the end of the job rather then delete them. * `spark.yarn.scheduler.heartbeat.interval-ms`, the interval in ms in which the Spark application master heartbeats into the YARN ResourceManager. Default is 5 seconds. -* `spark.yarn.max.worker.failures`, the maximum number of worker failures before failing the application. Default is the number of workers requested times 2 with minimum of 3. +* `spark.yarn.max.executor.failures`, the maximum number of executor failures before failing the application. Default is the number of executors requested times 2 with minimum of 3. # Launching Spark on YARN -Ensure that HADOOP_CONF_DIR or YARN_CONF_DIR points to the directory which contains the (client side) configuration files for the hadoop cluster. -This would be used to connect to the cluster, write to the dfs and submit jobs to the resource manager. +Ensure that HADOOP_CONF_DIR or YARN_CONF_DIR points to the directory which contains the (client side) configuration files for the Hadoop cluster. +These configs are used to connect to the cluster, write to the dfs, and connect to the YARN ResourceManager. -There are two scheduler mode that can be used to launch spark application on YARN. +There are two scheduler modes that can be used to launch Spark applications on YARN. In yarn-cluster mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In yarn-client mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. -## Launch spark application by YARN Client with yarn-standalone mode. +Unlike in Spark standalone and Mesos mode, in which the master's address is specified in the "master" parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. 
Thus, the master parameter is simply "yarn-client" or "yarn-cluster". -The command to launch the YARN Client is as follows: +## Launching a Spark application with yarn-cluster mode. + +The command to launch the Spark application on the cluster is as follows: SPARK_JAR= ./bin/spark-class org.apache.spark.deploy.yarn.Client \ --jar \ --class \ --args \ - --num-workers \ - --master-class - --master-memory \ - --worker-memory \ - --worker-cores \ + --num-executors \ + --driver-memory \ + --executor-memory \ + --executor-cores \ --name \ --queue \ --addJars \ @@ -82,58 +83,61 @@ For example: ./bin/spark-class org.apache.spark.deploy.yarn.Client \ --jar examples/target/scala-{{site.SCALA_BINARY_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \ --class org.apache.spark.examples.SparkPi \ - --args yarn-standalone \ - --num-workers 3 \ - --master-memory 4g \ - --worker-memory 2g \ - --worker-cores 1 - - # Examine the output (replace $YARN_APP_ID in the following with the "application identifier" output by the previous command) - # (Note: YARN_APP_LOGS_DIR is usually /tmp/logs or $HADOOP_HOME/logs/userlogs depending on the Hadoop version.) - $ cat $YARN_APP_LOGS_DIR/$YARN_APP_ID/container*_000001/stdout - Pi is roughly 3.13794 + --args yarn-cluster \ + --num-executors 3 \ + --driver-memory 4g \ + --executor-memory 2g \ + --executor-cores 1 -The above starts a YARN Client programs which start the default Application Master. Then SparkPi will be run as a child thread of Application Master, YARN Client will periodically polls the Application Master for status updates and displays them in the console. The client will exit once your application has finished running. +The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Viewing Logs" section below for how to see driver and executor logs. -With this mode, your application is actually run on the remote machine where the Application Master is run upon. Thus application that involve local interaction will not work well, e.g. spark-shell. +Because the application is run on a remote machine where the Application Master is running, applications that involve local interaction, such as spark-shell, will not work. -## Launch spark application with yarn-client mode. +## Launching a Spark application with yarn-client mode. -With yarn-client mode, the application will be launched locally. Just like running application or spark-shell on Local / Mesos / Standalone mode. The launch method is also the similar with them, just make sure that when you need to specify a master url, use "yarn-client" instead. And you also need to export the env value for SPARK_JAR and SPARK_YARN_APP_JAR +With yarn-client mode, the application will be launched locally, just like running an application or spark-shell on Local / Mesos / Standalone client mode. The launch method is also the same, just make sure to specify the master URL as "yarn-client". You also need to export the env value for SPARK_JAR. Configuration in yarn-client mode: -In order to tune worker core/number/memory etc. You need to export environment variables or add them to the spark configuration file (./conf/spark_env.sh). The following are the list of options. 
+In order to tune executor cores/number/memory etc., you need to export environment variables or add them to the spark configuration file (./conf/spark_env.sh). The following are the list of options. -* `SPARK_YARN_APP_JAR`, Path to your application's JAR file (required) -* `SPARK_WORKER_INSTANCES`, Number of workers to start (Default: 2) -* `SPARK_WORKER_CORES`, Number of cores for the workers (Default: 1). -* `SPARK_WORKER_MEMORY`, Memory per Worker (e.g. 1000M, 2G) (Default: 1G) -* `SPARK_MASTER_MEMORY`, Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb) +* `SPARK_EXECUTOR_INSTANCES`, Number of executors to start (Default: 2) +* `SPARK_EXECUTOR_CORES`, Number of cores per executor (Default: 1). +* `SPARK_EXECUTOR_MEMORY`, Memory per executor (e.g. 1000M, 2G) (Default: 1G) +* `SPARK_DRIVER_MEMORY`, Memory for driver (e.g. 1000M, 2G) (Default: 512 Mb) * `SPARK_YARN_APP_NAME`, The name of your application (Default: Spark) -* `SPARK_YARN_QUEUE`, The hadoop queue to use for allocation requests (Default: 'default') +* `SPARK_YARN_QUEUE`, The YARN queue to use for allocation requests (Default: 'default') * `SPARK_YARN_DIST_FILES`, Comma separated list of files to be distributed with the job. * `SPARK_YARN_DIST_ARCHIVES`, Comma separated list of archives to be distributed with the job. For example: SPARK_JAR=./assembly/target/scala-{{site.SCALA_BINARY_VERSION}}/spark-assembly-{{site.SPARK_VERSION}}-hadoop2.0.5-alpha.jar \ - SPARK_YARN_APP_JAR=examples/target/scala-{{site.SCALA_BINARY_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \ ./bin/run-example org.apache.spark.examples.SparkPi yarn-client +or SPARK_JAR=./assembly/target/scala-{{site.SCALA_BINARY_VERSION}}/spark-assembly-{{site.SPARK_VERSION}}-hadoop2.0.5-alpha.jar \ - SPARK_YARN_APP_JAR=examples/target/scala-{{site.SCALA_BINARY_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \ MASTER=yarn-client ./bin/spark-shell +## Viewing logs + +In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the yarn.log-aggregation-enable config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command. + + yarn logs -applicationId + +will print out the contents of all log files from all containers from the given application. + +When log aggregation isn't turned on, logs are retained locally on each machine under YARN_APP_LOGS_DIR, which is usually configured to /tmp/logs or $HADOOP_HOME/logs/userlogs depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. + # Building Spark for Hadoop/YARN 2.2.x -See [Building Spark with Maven](building-with-maven.html) for instructions on how to build Spark using the Maven process. +See [Building Spark with Maven](building-with-maven.html) for instructions on how to build Spark using Maven. -# Important Notes +# Important notes - Before Hadoop 2.2, YARN does not support cores in container resource requests. Thus, when running against an earlier version, the numbers of cores given via command line arguments cannot be passed to YARN. Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. 
-- The local directories used for spark will be the local directories configured for YARN (Hadoop Yarn config yarn.nodemanager.local-dirs). If the user specifies spark.local.dir, it will be ignored. -- The --files and --archives options support specifying file names with the # similar to Hadoop. For example you can specify: --files localtest.txt#appSees.txt and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name appSees.txt and your application should use the name as appSees.txt to reference it when running on YARN. +- The local directories used by Spark executors will be the local directories configured for YARN (Hadoop YARN config yarn.nodemanager.local-dirs). If the user specifies spark.local.dir, it will be ignored. +- The --files and --archives options support specifying file names with the # similar to Hadoop. For example you can specify: --files localtest.txt#appSees.txt and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name appSees.txt, and your application should use the name as appSees.txt to reference it when running on YARN. - The --addJars option allows the SparkContext.addJar function to work if you are using it with local files. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. diff --git a/docs/scala-programming-guide.md b/docs/scala-programming-guide.md index 506d3faa767f3..99412733d4268 100644 --- a/docs/scala-programming-guide.md +++ b/docs/scala-programming-guide.md @@ -365,7 +365,7 @@ res2: Int = 10 # Where to Go from Here -You can see some [example Spark programs](http://spark.incubator.apache.org/examples.html) on the Spark website. +You can see some [example Spark programs](http://spark.apache.org/examples.html) on the Spark website. In addition, Spark includes several samples in `examples/src/main/scala`. Some of them have both Spark versions and local (non-parallel) versions, allowing you to see what had to be changed to make the program run on a cluster. You can run them using by passing the class name to the `bin/run-example` script included in Spark; for example: ./bin/run-example org.apache.spark.examples.SparkPi diff --git a/docs/security.md b/docs/security.md new file mode 100644 index 0000000000000..9e4218fbcfe7d --- /dev/null +++ b/docs/security.md @@ -0,0 +1,18 @@ +--- +layout: global +title: Spark Security +--- + +Spark currently supports authentication via a shared secret. Authentication can be configured to be on via the `spark.authenticate` configuration parameter. This parameter controls whether the Spark communication protocols do authentication using the shared secret. This authentication is a basic handshake to make sure both sides have the same shared secret and are allowed to communicate. If the shared secret is not identical they will not be allowed to communicate. + +The Spark UI can also be secured by using javax servlet filters. A user may want to secure the UI if it has data that other users should not be allowed to see. The javax servlet filter specified by the user can authenticate the user and then once the user is logged in, Spark can compare that user versus the view acls to make sure they are authorized to view the UI. The configs 'spark.ui.acls.enable' and 'spark.ui.view.acls' control the behavior of the acls. Note that the person who started the application always has view access to the UI. 
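The `spark.authenticate` parameter described above pairs with the `spark.authenticate.secret` key from the configuration table earlier in this patch. A sketch of turning both on for a non-YARN deployment follows; the master URL and secret are placeholders, and the same values must also be supplied to the master and workers (for example through `SPARK_JAVA_OPTS`):

```scala
import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .setMaster("spark://master:7077")              // placeholder standalone master URL
  .setAppName("authenticated-app")
  .set("spark.authenticate", "true")             // require the shared-secret handshake
  .set("spark.authenticate.secret", "change-me") // placeholder secret

val sc = new SparkContext(conf)
```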
+ +For Spark on Yarn deployments, configuring `spark.authenticate` to true will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. The Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters. If an authentication filter is enabled, the acls controls can be used by control which users can via the Spark UI. + +For other types of Spark deployments, the spark config `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. The UI can be secured using a javax servlet filter installed via `spark.ui.filters`. If an authentication filter is enabled, the acls controls can be used by control which users can via the Spark UI. + +IMPORTANT NOTE: The NettyBlockFetcherIterator is not secured so do not use netty for the shuffle is running with authentication on. + +See [Spark Configuration](configuration.html) for more details on the security configs. + +See org.apache.spark.SecurityManager for implementation details about security. diff --git a/docs/spark-debugger.md b/docs/spark-debugger.md index 11c51d5cde7c9..891c2bfa8943d 100644 --- a/docs/spark-debugger.md +++ b/docs/spark-debugger.md @@ -2,7 +2,7 @@ layout: global title: The Spark Debugger --- -**Summary:** The Spark debugger provides replay debugging for deterministic (logic) errors in Spark programs. It's currently in development, but you can try it out in the [arthur branch](https://github.com/apache/incubator-spark/tree/arthur). +**Summary:** The Spark debugger provides replay debugging for deterministic (logic) errors in Spark programs. It's currently in development, but you can try it out in the [arthur branch](https://github.com/apache/spark/tree/arthur). ## Introduction @@ -19,7 +19,7 @@ For deterministic errors, debugging a Spark program is now as easy as debugging ## Approach -As your Spark program runs, the slaves report key events back to the master -- for example, RDD creations, RDD contents, and uncaught exceptions. (A full list of event types is in [EventLogging.scala](https://github.com/apache/incubator-spark/blob/arthur/core/src/main/scala/spark/EventLogging.scala).) The master logs those events, and you can load the event log into the debugger after your program is done running. +As your Spark program runs, the slaves report key events back to the master -- for example, RDD creations, RDD contents, and uncaught exceptions. (A full list of event types is in [EventLogging.scala](https://github.com/apache/spark/blob/arthur/core/src/main/scala/spark/EventLogging.scala).) The master logs those events, and you can load the event log into the debugger after your program is done running. _A note on nondeterminism:_ For fault recovery, Spark requires RDD transformations (for example, the function passed to `RDD.map`) to be deterministic. The Spark debugger also relies on this property, and it can also warn you if your transformation is nondeterministic. This works by checksumming the contents of each RDD and comparing the checksums from the original execution to the checksums after recomputing the RDD in the debugger. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 57e88581616a2..f9904d45013f6 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -58,11 +58,21 @@ do is as follows.
    +First, we import the names of the Spark Streaming classes, and some implicit +conversions from StreamingContext into our environment, to add useful methods to +other classes we need (like DStream). -First, we create a -[StreamingContext](api/streaming/index.html#org.apache.spark.streaming.StreamingContext) object, -which is the main entry point for all streaming -functionality. Besides Spark's configuration, we specify that any DStream will be processed +[StreamingContext](api/streaming/index.html#org.apache.spark.streaming.StreamingContext) is the +main entry point for all streaming functionality. + +{% highlight scala %} +import org.apache.spark.streaming._ +import org.apache.spark.streaming.StreamingContext._ +{% endhighlight %} + +Then we create a +[StreamingContext](api/streaming/index.html#org.apache.spark.streaming.StreamingContext) object. +Besides Spark's configuration, we specify that any DStream will be processed in 1 second batches. {% highlight scala %} @@ -98,7 +108,7 @@ val pairs = words.map(word => (word, 1)) val wordCounts = pairs.reduceByKey(_ + _) // Print a few of the counts to the console -wordCount.print() +wordCounts.print() {% endhighlight %} The `words` DStream is further mapped (one-to-one transformation) to a DStream of `(word, @@ -178,7 +188,7 @@ JavaPairDStream wordCounts = pairs.reduceByKey( return i1 + i2; } }); -wordCount.print(); // Print a few of the counts to the console +wordCounts.print(); // Print a few of the counts to the console {% endhighlight %} The `words` DStream is further mapped (one-to-one transformation) to a DStream of `(word, @@ -262,6 +272,24 @@ Time: 1357008430000 ms
Class | Function Type
    +If you plan to run the Scala code for Spark Streaming-based use cases in the Spark +shell, you should start the shell with the SparkConfiguration pre-configured to +discard old batches periodically: + +{% highlight bash %} +$ SPARK_JAVA_OPTS=-Dspark.cleaner.ttl=10000 bin/spark-shell +{% endhighlight %} + +... and create your StreamingContext by wrapping the existing interactive shell +SparkContext object, `sc`: + +{% highlight scala %} +val ssc = new StreamingContext(sc, Seconds(1)) +{% endhighlight %} + +When working with the shell, you may also need to send a `^D` to your netcat session +to force the pipeline to print the word counts to the console at the sink. + *************************************************************************************************** # Basics @@ -511,7 +539,7 @@ common ones are as follows. updateStateByKey(func) Return a new "state" DStream where the state for each key is updated by applying the given function on the previous state of the key and the new values for the key. This can be - used to maintain arbitrary state data for each ket. + used to maintain arbitrary state data for each key. diff --git a/docs/tuning.md b/docs/tuning.md index 6b010aed618a3..093df3187a789 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -44,7 +44,10 @@ This setting configures the serializer used for not only shuffling data between nodes but also when serializing RDDs to disk. The only reason Kryo is not the default is because of the custom registration requirement, but we recommend trying it in any network-intensive application. -Finally, to register your classes with Kryo, create a public class that extends +Spark automatically includes Kryo serializers for the many commonly-used core Scala classes covered +in the AllScalaRegistrar from the [Twitter chill](https://github.com/twitter/chill) library. + +To register your own custom classes with Kryo, create a public class that extends [`org.apache.spark.serializer.KryoRegistrator`](api/core/index.html#org.apache.spark.serializer.KryoRegistrator) and set the `spark.kryo.registrator` config property to point to it, as follows: @@ -72,8 +75,8 @@ If your objects are large, you may also need to increase the `spark.kryoserializ config property. The default is 2, but this value needs to be large enough to hold the *largest* object you will serialize. -Finally, if you don't register your classes, Kryo will still work, but it will have to store the -full class name with each object, which is wasteful. +Finally, if you don't register your custom classes, Kryo will still work, but it will have to store +the full class name with each object, which is wasteful. # Memory Tuning @@ -160,8 +163,8 @@ their work directories), *not* on your driver program. **Cache Size Tuning** One important configuration parameter for GC is the amount of memory that should be used for caching RDDs. -By default, Spark uses 66% of the configured executor memory (`spark.executor.memory` or `SPARK_MEM`) to -cache RDDs. This means that 33% of memory is available for any objects created during task execution. +By default, Spark uses 60% of the configured executor memory (`spark.executor.memory`) to +cache RDDs. This means that 40% of memory is available for any objects created during task execution. In case your tasks slow down and you find that your JVM is garbage-collecting frequently or running out of memory, lowering this value will help reduce the memory consumption. 
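To make the two tuning knobs above concrete, the following is an editor's sketch, not part of the patch: a custom Kryo registrator wired up through the `spark.kryo.registrator` property mentioned above, plus a lowered cache fraction, assuming `spark.storage.memoryFraction` is the property behind the 60% default discussed in the preceding paragraph. `MyClass1`, the registrator's package name, and the 0.5 value are illustrative placeholders.

{% highlight scala %}
import com.esotericsoftware.kryo.Kryo
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoRegistrator

// Hypothetical user class standing in for whatever you shuffle or cache.
case class MyClass1(id: Int, name: String)

// Register custom classes so Kryo does not store the full class name with each object.
class MyRegistrator extends KryoRegistrator {
  override def registerClasses(kryo: Kryo) {
    kryo.register(classOf[MyClass1])
  }
}

val conf = new SparkConf()
  .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .set("spark.kryo.registrator", "mypackage.MyRegistrator") // fully-qualified name of the class above
  // Assumption: spark.storage.memoryFraction governs the cache fraction described above (default 0.6).
  .set("spark.storage.memoryFraction", "0.5")
{% endhighlight %}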
To change this to say 50%, you can call diff --git a/ec2/README b/ec2/README index 433da37b4c37c..72434f24bf98d 100644 --- a/ec2/README +++ b/ec2/README @@ -1,4 +1,4 @@ This folder contains a script, spark-ec2, for launching Spark clusters on Amazon EC2. Usage instructions are available online at: -http://spark.incubator.apache.org/docs/latest/ec2-scripts.html +http://spark.apache.org/docs/latest/ec2-scripts.html diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index b0512ca891ad6..d8840c94ac17c 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -73,7 +73,7 @@ def parse_args(): parser.add_option("-v", "--spark-version", default="0.9.0", help="Version of Spark to use: 'X.Y.Z' or a specific git hash") parser.add_option("--spark-git-repo", - default="https://github.com/apache/incubator-spark", + default="https://github.com/apache/spark", help="Github repo from which to checkout supplied commit hash") parser.add_option("--hadoop-major-version", default="1", help="Major version of Hadoop (default: 1)") @@ -398,15 +398,13 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): if any((master_nodes, slave_nodes)): print ("Found %d master(s), %d slaves" % (len(master_nodes), len(slave_nodes))) - if (master_nodes != [] and slave_nodes != []) or not die_on_error: + if master_nodes != [] or not die_on_error: return (master_nodes, slave_nodes) else: if master_nodes == [] and slave_nodes != []: - print "ERROR: Could not find master in group " + cluster_name + "-master" - elif master_nodes != [] and slave_nodes == []: - print "ERROR: Could not find slaves in group " + cluster_name + "-slaves" + print >> sys.stderr, "ERROR: Could not find master in group " + cluster_name + "-master" else: - print "ERROR: Could not find any existing cluster" + print >> sys.stderr, "ERROR: Could not find any existing cluster" sys.exit(1) @@ -680,6 +678,9 @@ def real_main(): opts.zone = random.choice(conn.get_all_zones()).name if action == "launch": + if opts.slaves <= 0: + print >> sys.stderr, "ERROR: You have to start at least 1 slave" + sys.exit(1) if opts.resume: (master_nodes, slave_nodes) = get_existing_cluster( conn, opts, cluster_name) diff --git a/examples/pom.xml b/examples/pom.xml index 12a11821a4947..382a38d9400b9 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.0.0-incubating-SNAPSHOT + 1.0.0-SNAPSHOT ../pom.xml @@ -29,22 +29,21 @@ spark-examples_2.10 jar Spark Project Examples - http://spark.incubator.apache.org/ - - - - apache-repo - Apache Repository - https://repository.apache.org/content/repositories/releases - - true - - - false - - - + http://spark.apache.org/ + + + + yarn-alpha + + + org.apache.avro + avro + + + + @@ -169,14 +168,6 @@ org.apache.cassandra.deps avro - - org.sonatype.sisu.inject - * - - - org.xerial.snappy - * - diff --git a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java index d552c47b22231..6b49244ba459d 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java @@ -45,7 +45,7 @@ static class DataPoint implements Serializable { double y; } - static class ParsePoint extends Function { + static class ParsePoint implements Function { private static final Pattern SPACE = Pattern.compile(" "); @Override @@ -60,7 +60,7 @@ public DataPoint call(String line) { } } - static class VectorSum extends Function2 { + static class VectorSum implements 
Function2 { @Override public double[] call(double[] a, double[] b) { double[] result = new double[D]; @@ -71,7 +71,7 @@ public double[] call(double[] a, double[] b) { } } - static class ComputeGradient extends Function { + static class ComputeGradient implements Function { private final double[] weights; ComputeGradient(double[] weights) { diff --git a/examples/src/main/java/org/apache/spark/examples/JavaKMeans.java b/examples/src/main/java/org/apache/spark/examples/JavaKMeans.java index 0dc879275a22a..2d797279d5bcc 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaKMeans.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaKMeans.java @@ -98,7 +98,7 @@ public Vector call(String line) { double tempDist; do { // allocate each vector to closest centroid - JavaPairRDD closest = data.map( + JavaPairRDD closest = data.mapToPair( new PairFunction() { @Override public Tuple2 call(Vector vector) { diff --git a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java index 9eb1cadd71d22..617e4a6d045e0 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java @@ -108,7 +108,7 @@ public static void main(String[] args) { JavaRDD dataSet = (args.length == 2) ? jsc.textFile(args[1]) : jsc.parallelize(exampleApacheLogs); - JavaPairRDD, Stats> extracted = dataSet.map(new PairFunction, Stats>() { + JavaPairRDD, Stats> extracted = dataSet.mapToPair(new PairFunction, Stats>() { @Override public Tuple2, Stats> call(String s) { return new Tuple2, Stats>(extractKey(s), extractStats(s)); @@ -124,7 +124,7 @@ public Stats call(Stats stats, Stats stats2) { List, Stats>> output = counts.collect(); for (Tuple2 t : output) { - System.out.println(t._1 + "\t" + t._2); + System.out.println(t._1() + "\t" + t._2()); } System.exit(0); } diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java index a84245b0c7449..eb70fb547564c 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java @@ -42,7 +42,7 @@ public final class JavaPageRank { private static final Pattern SPACES = Pattern.compile("\\s+"); - private static class Sum extends Function2 { + private static class Sum implements Function2 { @Override public Double call(Double a, Double b) { return a + b; @@ -66,7 +66,7 @@ public static void main(String[] args) throws Exception { JavaRDD lines = ctx.textFile(args[1], 1); // Loads all URLs from input file and initialize their neighbors. - JavaPairRDD> links = lines.map(new PairFunction() { + JavaPairRDD> links = lines.mapToPair(new PairFunction() { @Override public Tuple2 call(String s) { String[] parts = SPACES.split(s); @@ -86,12 +86,12 @@ public Double call(List rs) { for (int current = 0; current < Integer.parseInt(args[2]); current++) { // Calculates URL contributions to the rank of other URLs. 
JavaPairRDD contribs = links.join(ranks).values() - .flatMap(new PairFlatMapFunction, Double>, String, Double>() { + .flatMapToPair(new PairFlatMapFunction, Double>, String, Double>() { @Override public Iterable> call(Tuple2, Double> s) { List> results = new ArrayList>(); - for (String n : s._1) { - results.add(new Tuple2(n, s._2 / s._1.size())); + for (String n : s._1()) { + results.add(new Tuple2(n, s._2() / s._1().size())); } return results; } @@ -109,7 +109,7 @@ public Double call(Double sum) { // Collects all URL ranks and dump them to console. List> output = ranks.collect(); for (Tuple2 tuple : output) { - System.out.println(tuple._1 + " has rank: " + tuple._2 + "."); + System.out.println(tuple._1() + " has rank: " + tuple._2() + "."); } System.exit(0); diff --git a/examples/src/main/java/org/apache/spark/examples/JavaTC.java b/examples/src/main/java/org/apache/spark/examples/JavaTC.java index 2ceb0fd94ba65..6cfe25c80ecc6 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaTC.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaTC.java @@ -50,7 +50,7 @@ static List> generateGraph() { return new ArrayList>(edges); } - static class ProjectFn extends PairFunction>, + static class ProjectFn implements PairFunction>, Integer, Integer> { static final ProjectFn INSTANCE = new ProjectFn(); @@ -77,7 +77,7 @@ public static void main(String[] args) { // the graph to obtain the path (x, z). // Because join() joins on keys, the edges are stored in reversed order. - JavaPairRDD edges = tc.map( + JavaPairRDD edges = tc.mapToPair( new PairFunction, Integer, Integer>() { @Override public Tuple2 call(Tuple2 e) { @@ -91,7 +91,7 @@ public Tuple2 call(Tuple2 e) { oldCount = nextCount; // Perform the join, obtaining an RDD of (y, (z, x)) pairs, // then project the result to obtain the new (x, z) paths. 
- tc = tc.union(tc.join(edges).map(ProjectFn.INSTANCE)).distinct().cache(); + tc = tc.union(tc.join(edges).mapToPair(ProjectFn.INSTANCE)).distinct().cache(); nextCount = tc.count(); } while (nextCount != oldCount); diff --git a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java index 6651f98d56711..3ae1d8f7ca938 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java @@ -49,7 +49,7 @@ public Iterable call(String s) { } }); - JavaPairRDD ones = words.map(new PairFunction() { + JavaPairRDD ones = words.mapToPair(new PairFunction() { @Override public Tuple2 call(String s) { return new Tuple2(s, 1); @@ -65,7 +65,7 @@ public Integer call(Integer i1, Integer i2) { List> output = counts.collect(); for (Tuple2 tuple : output) { - System.out.println(tuple._1 + ": " + tuple._2); + System.out.println(tuple._1() + ": " + tuple._2()); } System.exit(0); } diff --git a/examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java b/examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java index 435a86e62abc5..64a3a04fb7296 100644 --- a/examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java +++ b/examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java @@ -35,7 +35,7 @@ */ public final class JavaALS { - static class ParseRating extends Function { + static class ParseRating implements Function { private static final Pattern COMMA = Pattern.compile(","); @Override @@ -48,7 +48,7 @@ public Rating call(String line) { } } - static class FeaturesToString extends Function, String> { + static class FeaturesToString implements Function, String> { @Override public String call(Tuple2 element) { return element._1() + "," + Arrays.toString(element._2()); diff --git a/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java b/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java index 4b2658f257b3c..76ebdccfd6b67 100644 --- a/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java +++ b/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java @@ -32,7 +32,7 @@ */ public final class JavaKMeans { - static class ParsePoint extends Function { + static class ParsePoint implements Function { private static final Pattern SPACE = Pattern.compile(" "); @Override diff --git a/examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java b/examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java index 21586ce817d09..667c72f379e71 100644 --- a/examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java +++ b/examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java @@ -34,7 +34,7 @@ */ public final class JavaLR { - static class ParsePoint extends Function { + static class ParsePoint implements Function { private static final Pattern COMMA = Pattern.compile(","); private static final Pattern SPACE = Pattern.compile(" "); diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java index 2ffd351b4e498..d704be08d6945 100644 --- a/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java +++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java @@ -89,7 +89,7 @@ public Iterable call(String x) { } }); - JavaPairDStream wordCounts = words.map( + 
JavaPairDStream wordCounts = words.mapToPair( new PairFunction() { @Override public Tuple2 call(String s) { diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java index 7777c9832abd3..7f68d451e9b31 100644 --- a/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java @@ -69,7 +69,7 @@ public Iterable call(String x) { return Lists.newArrayList(SPACE.split(x)); } }); - JavaPairDStream wordCounts = words.map( + JavaPairDStream wordCounts = words.mapToPair( new PairFunction() { @Override public Tuple2 call(String s) { diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaQueueStream.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaQueueStream.java index 26c44620abec1..88ad341641e0a 100644 --- a/examples/src/main/java/org/apache/spark/streaming/examples/JavaQueueStream.java +++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaQueueStream.java @@ -63,7 +63,7 @@ public static void main(String[] args) throws Exception { // Create the QueueInputDStream and use it do some processing JavaDStream inputStream = ssc.queueStream(rddQueue); - JavaPairDStream mappedStream = inputStream.map( + JavaPairDStream mappedStream = inputStream.mapToPair( new PairFunction() { @Override public Tuple2 call(Integer i) { diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala new file mode 100644 index 0000000000000..ee283ce6abac2 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples + +import java.nio.ByteBuffer +import scala.collection.JavaConversions._ +import scala.collection.mutable.ListBuffer +import scala.collection.immutable.Map +import org.apache.cassandra.hadoop.ConfigHelper +import org.apache.cassandra.hadoop.cql3.CqlPagingInputFormat +import org.apache.cassandra.hadoop.cql3.CqlConfigHelper +import org.apache.cassandra.hadoop.cql3.CqlOutputFormat +import org.apache.cassandra.utils.ByteBufferUtil +import org.apache.hadoop.mapreduce.Job +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ + +/* + Need to create following keyspace and column family in cassandra before running this example + Start CQL shell using ./bin/cqlsh and execute following commands + CREATE KEYSPACE retail WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}; + use retail; + CREATE TABLE salecount (prod_id text, sale_count int, PRIMARY KEY (prod_id)); + CREATE TABLE ordercf (user_id text, + time timestamp, + prod_id text, + quantity int, + PRIMARY KEY (user_id, time)); + INSERT INTO ordercf (user_id, + time, + prod_id, + quantity) VALUES ('bob', 1385983646000, 'iphone', 1); + INSERT INTO ordercf (user_id, + time, + prod_id, + quantity) VALUES ('tom', 1385983647000, 'samsung', 4); + INSERT INTO ordercf (user_id, + time, + prod_id, + quantity) VALUES ('dora', 1385983648000, 'nokia', 2); + INSERT INTO ordercf (user_id, + time, + prod_id, + quantity) VALUES ('charlie', 1385983649000, 'iphone', 2); +*/ + +/** + * This example demonstrates how to read and write to cassandra column family created using CQL3 + * using Spark. + * Parameters : + * Usage: ./bin/run-example org.apache.spark.examples.CassandraCQLTest local[2] localhost 9160 + * + */ +object CassandraCQLTest { + + def main(args: Array[String]) { + val sc = new SparkContext(args(0), + "CQLTestApp", + System.getenv("SPARK_HOME"), + SparkContext.jarOfClass(this.getClass)) + val cHost: String = args(1) + val cPort: String = args(2) + val KeySpace = "retail" + val InputColumnFamily = "ordercf" + val OutputColumnFamily = "salecount" + + val job = new Job() + job.setInputFormatClass(classOf[CqlPagingInputFormat]) + ConfigHelper.setInputInitialAddress(job.getConfiguration(), cHost) + ConfigHelper.setInputRpcPort(job.getConfiguration(), cPort) + ConfigHelper.setInputColumnFamily(job.getConfiguration(), KeySpace, InputColumnFamily) + ConfigHelper.setInputPartitioner(job.getConfiguration(), "Murmur3Partitioner") + CqlConfigHelper.setInputCQLPageRowSize(job.getConfiguration(), "3") + + /** CqlConfigHelper.setInputWhereClauses(job.getConfiguration(), "user_id='bob'") */ + + /** An UPDATE writes one or more columns to a record in a Cassandra column family */ + val query = "UPDATE " + KeySpace + "." + OutputColumnFamily + " SET sale_count = ? 
" + CqlConfigHelper.setOutputCql(job.getConfiguration(), query) + + job.setOutputFormatClass(classOf[CqlOutputFormat]) + ConfigHelper.setOutputColumnFamily(job.getConfiguration(), KeySpace, OutputColumnFamily) + ConfigHelper.setOutputInitialAddress(job.getConfiguration(), cHost) + ConfigHelper.setOutputRpcPort(job.getConfiguration(), cPort) + ConfigHelper.setOutputPartitioner(job.getConfiguration(), "Murmur3Partitioner") + + val casRdd = sc.newAPIHadoopRDD(job.getConfiguration(), + classOf[CqlPagingInputFormat], + classOf[java.util.Map[String,ByteBuffer]], + classOf[java.util.Map[String,ByteBuffer]]) + + println("Count: " + casRdd.count) + val productSaleRDD = casRdd.map { + case (key, value) => { + (ByteBufferUtil.string(value.get("prod_id")), ByteBufferUtil.toInt(value.get("quantity"))) + } + } + val aggregatedRDD = productSaleRDD.reduceByKey(_ + _) + aggregatedRDD.collect().foreach { + case (productId, saleCount) => println(productId + ":" + saleCount) + } + + val casoutputCF = aggregatedRDD.map { + case (productId, saleCount) => { + val outColFamKey = Map("prod_id" -> ByteBufferUtil.bytes(productId)) + val outKey: java.util.Map[String, ByteBuffer] = outColFamKey + var outColFamVal = new ListBuffer[ByteBuffer] + outColFamVal += ByteBufferUtil.bytes(saleCount) + val outVal: java.util.List[ByteBuffer] = outColFamVal + (outKey, outVal) + } + } + + casoutputCF.saveAsNewAPIHadoopFile( + KeySpace, + classOf[java.util.Map[String, ByteBuffer]], + classOf[java.util.List[ByteBuffer]], + classOf[CqlOutputFormat], + job.getConfiguration() + ) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index 17bafc2218a31..ce4b3c8451e00 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -128,11 +128,11 @@ object SparkALS { println("Iteration " + iter + ":") ms = sc.parallelize(0 until M, slices) .map(i => update(i, msb.value(i), usb.value, Rc.value)) - .toArray + .collect() msb = sc.broadcast(ms) // Re-broadcast ms because it was updated us = sc.parallelize(0 until U, slices) .map(i => update(i, usb.value(i), msb.value, algebra.transpose(Rc.value))) - .toArray + .collect() usb = sc.broadcast(us) // Re-broadcast us because it was updated println("RMSE = " + rmse(R, ms, us)) println() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparkSVD.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparkSVD.scala index 19676fcc1a2b0..ce2b133368e85 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparkSVD.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SparkSVD.scala @@ -54,6 +54,6 @@ object SparkSVD { val s = decomposed.S.data val v = decomposed.V.data - println("singular values = " + s.toArray.mkString) + println("singular values = " + s.collect().mkString) } } diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala index 3d7b390724e77..62d3a52615584 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala @@ -23,7 +23,7 @@ import scala.util.Random import akka.actor.{Actor, ActorRef, Props, actorRef2Scala} -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SecurityManager} import 
org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions import org.apache.spark.streaming.receivers.Receiver @@ -112,8 +112,9 @@ object FeederActor { } val Seq(host, port) = args.toSeq - - val actorSystem = AkkaUtils.createActorSystem("test", host, port.toInt, conf = new SparkConf)._1 + val conf = new SparkConf + val actorSystem = AkkaUtils.createActorSystem("test", host, port.toInt, conf = conf, + securityManager = new SecurityManager(conf))._1 val feeder = actorSystem.actorOf(Props[FeederActor], "FeederActor") println("Feeder started as:" + feeder) diff --git a/external/flume/pom.xml b/external/flume/pom.xml index a0e8b84514ef6..f21963531574b 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.0.0-incubating-SNAPSHOT + 1.0.0-SNAPSHOT ../../pom.xml @@ -29,7 +29,21 @@ spark-streaming-flume_2.10 jar Spark Project External Flume - http://spark.incubator.apache.org/ + http://spark.apache.org/ + + + + + yarn-alpha + + + org.apache.avro + avro + + + + @@ -53,10 +67,6 @@ org.jboss.netty netty - - org.xerial.snappy - * - diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index fb37cd79884c8..343e1fabd823f 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.0.0-incubating-SNAPSHOT + 1.0.0-SNAPSHOT ../../pom.xml @@ -29,7 +29,21 @@ spark-streaming-kafka_2.10 jar Spark Project External Kafka - http://spark.incubator.apache.org/ + http://spark.apache.org/ + + + + + yarn-alpha + + + org.apache.avro + avro + + + + diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index cfa1870e982fe..3710a63541d78 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.0.0-incubating-SNAPSHOT + 1.0.0-SNAPSHOT ../../pom.xml @@ -29,21 +29,21 @@ spark-streaming-mqtt_2.10 jar Spark Project External MQTT - http://spark.incubator.apache.org/ + http://spark.apache.org/ - - - mqtt-repo - MQTT Repository - https://repo.eclipse.org/content/repositories/paho-releases - - true - - - false - - - + + + + yarn-alpha + + + org.apache.avro + avro + + + + diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 077f88dc59bab..398b9f4fbaa7d 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.0.0-incubating-SNAPSHOT + 1.0.0-SNAPSHOT ../../pom.xml @@ -29,7 +29,21 @@ spark-streaming-twitter_2.10 jar Spark Project External Twitter - http://spark.incubator.apache.org/ + http://spark.apache.org/ + + + + + yarn-alpha + + + org.apache.avro + avro + + + + diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 4c68294b7b5af..77e957f404645 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.0.0-incubating-SNAPSHOT + 1.0.0-SNAPSHOT ../../pom.xml @@ -29,7 +29,21 @@ spark-streaming-zeromq_2.10 jar Spark Project External ZeroMQ - http://spark.incubator.apache.org/ + http://spark.apache.org/ + + + + + yarn-alpha + + + org.apache.avro + avro + + + + diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala index c989ec0f27465..b254e00714621 100644 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala +++ 
b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala @@ -75,7 +75,7 @@ object ZeroMQUtils { ): JavaDStream[T] = { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val fn = (x: Seq[ByteString]) => bytesToObjects.apply(x.map(_.toArray).toArray).toIterator + val fn = (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).toIterator createStream[T](jssc.ssc, publisherUrl, subscribe, fn, storageLevel, supervisorStrategy) } @@ -99,7 +99,7 @@ object ZeroMQUtils { ): JavaDStream[T] = { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val fn = (x: Seq[ByteString]) => bytesToObjects.apply(x.map(_.toArray).toArray).toIterator + val fn = (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).toIterator createStream[T](jssc.ssc, publisherUrl, subscribe, fn, storageLevel) } @@ -122,7 +122,7 @@ object ZeroMQUtils { ): JavaDStream[T] = { implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] - val fn = (x: Seq[ByteString]) => bytesToObjects.apply(x.map(_.toArray).toArray).toIterator + val fn = (x: Seq[ByteString]) => bytesToObjects.call(x.map(_.toArray).toArray).toIterator createStream[T](jssc.ssc, publisherUrl, subscribe, fn) } } diff --git a/extras/README.md b/extras/README.md new file mode 100644 index 0000000000000..1b4174b7d5cff --- /dev/null +++ b/extras/README.md @@ -0,0 +1 @@ +This directory contains build components not included by default in Spark's build. diff --git a/extras/java8-tests/README.md b/extras/java8-tests/README.md new file mode 100644 index 0000000000000..e95b73ac7702a --- /dev/null +++ b/extras/java8-tests/README.md @@ -0,0 +1,24 @@ +# Java 8 Test Suites + +These tests require having Java 8 installed and are isolated from the main Spark build. +If Java 8 is not your system's default Java version, you will need to point Spark's build +to your Java location. The set-up depends a bit on the build system: + +* Sbt users can either set JAVA_HOME to the location of a Java 8 JDK or explicitly pass + `-java-home` to the sbt launch script. If a Java 8 JDK is detected sbt will automatically + include the Java 8 test project. + + `$ JAVA_HOME=/opt/jdk1.8.0/ sbt/sbt clean "test-only org.apache.spark.Java8APISuite"` + +* For Maven users, + + Maven users can also refer to their Java 8 directory using JAVA_HOME. However, Maven will not + automatically detect the presence of a Java 8 JDK, so a special build profile `-Pjava8-tests` + must be used. + + `$ JAVA_HOME=/opt/jdk1.8.0/ mvn clean install -DskipTests` + `$ JAVA_HOME=/opt/jdk1.8.0/ mvn test -Pjava8-tests -DwildcardSuites=org.apache.spark.Java8APISuite` + + Note that the above command can only be run from project root directory since this module + depends on core and the test-jars of core and streaming. This means an install step is + required to make the test dependencies visible to the Java 8 sub-project. 
diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml new file mode 100644 index 0000000000000..602f66f9c5cf1 --- /dev/null +++ b/extras/java8-tests/pom.xml @@ -0,0 +1,151 @@ + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + java8-tests_2.10 + pom + Spark Project Java8 Tests POM + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + test-jar + + + com.novocode + junit-interface + test + + + org.scalatest + scalatest_${scala.binary.version} + test + + + + + + java8-tests + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + test + + test + + + + + + + + file:src/test/resources/log4j.properties + + + false + + **/Suite*.java + **/*Suite.java + + + + + org.apache.maven.plugins + maven-compiler-plugin + + + test-compile-first + process-test-resources + + testCompile + + + + + true + true + true + 1.8 + 1.8 + 1.8 + UTF-8 + 1024m + + + + + net.alchim31.maven + scala-maven-plugin + + + none + + + scala-compile-first + none + + + scala-test-compile-first + none + + + attach-scaladocs + none + + + + + org.scalatest + scalatest-maven-plugin + + + test + none + + + + + + diff --git a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java new file mode 100644 index 0000000000000..f67251217ed4a --- /dev/null +++ b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java @@ -0,0 +1,391 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +import java.io.File; +import java.io.Serializable; +import java.util.*; + +import scala.Tuple2; + +import com.google.common.base.Optional; +import com.google.common.io.Files; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapred.SequenceFileOutputFormat; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.*; + +/** + * Most of these tests replicate org.apache.spark.JavaAPISuite using java 8 + * lambda syntax. 
+ */ +public class Java8APISuite implements Serializable { + static int foreachCalls = 0; + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaAPISuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.driver.port"); + } + + @Test + public void foreachWithAnonymousClass() { + foreachCalls = 0; + JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); + rdd.foreach(new VoidFunction() { + @Override + public void call(String s) { + foreachCalls++; + } + }); + Assert.assertEquals(2, foreachCalls); + } + + @Test + public void foreach() { + foreachCalls = 0; + JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); + rdd.foreach((x) -> foreachCalls++); + Assert.assertEquals(2, foreachCalls); + } + + @Test + public void groupBy() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); + Function isOdd = x -> x % 2 == 0; + JavaPairRDD> oddsAndEvens = rdd.groupBy(isOdd); + Assert.assertEquals(2, oddsAndEvens.count()); + Assert.assertEquals(2, oddsAndEvens.lookup(true).get(0).size()); // Evens + Assert.assertEquals(5, oddsAndEvens.lookup(false).get(0).size()); // Odds + + oddsAndEvens = rdd.groupBy(isOdd, 1); + Assert.assertEquals(2, oddsAndEvens.count()); + Assert.assertEquals(2, oddsAndEvens.lookup(true).get(0).size()); // Evens + Assert.assertEquals(5, oddsAndEvens.lookup(false).get(0).size()); // Odds + } + + @Test + public void leftOuterJoin() { + JavaPairRDD rdd1 = sc.parallelizePairs(Arrays.asList( + new Tuple2(1, 1), + new Tuple2(1, 2), + new Tuple2(2, 1), + new Tuple2(3, 1) + )); + JavaPairRDD rdd2 = sc.parallelizePairs(Arrays.asList( + new Tuple2(1, 'x'), + new Tuple2(2, 'y'), + new Tuple2(2, 'z'), + new Tuple2(4, 'w') + )); + List>>> joined = + rdd1.leftOuterJoin(rdd2).collect(); + Assert.assertEquals(5, joined.size()); + Tuple2>> firstUnmatched = + rdd1.leftOuterJoin(rdd2).filter(tup -> !tup._2()._2().isPresent()).first(); + Assert.assertEquals(3, firstUnmatched._1().intValue()); + } + + @Test + public void foldReduce() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); + Function2 add = (a, b) -> a + b; + + int sum = rdd.fold(0, add); + Assert.assertEquals(33, sum); + + sum = rdd.reduce(add); + Assert.assertEquals(33, sum); + } + + @Test + public void foldByKey() { + List> pairs = Arrays.asList( + new Tuple2(2, 1), + new Tuple2(2, 1), + new Tuple2(1, 1), + new Tuple2(3, 2), + new Tuple2(3, 1) + ); + JavaPairRDD rdd = sc.parallelizePairs(pairs); + JavaPairRDD sums = rdd.foldByKey(0, (a, b) -> a + b); + Assert.assertEquals(1, sums.lookup(1).get(0).intValue()); + Assert.assertEquals(2, sums.lookup(2).get(0).intValue()); + Assert.assertEquals(3, sums.lookup(3).get(0).intValue()); + } + + @Test + public void reduceByKey() { + List> pairs = Arrays.asList( + new Tuple2(2, 1), + new Tuple2(2, 1), + new Tuple2(1, 1), + new Tuple2(3, 2), + new Tuple2(3, 1) + ); + JavaPairRDD rdd = sc.parallelizePairs(pairs); + JavaPairRDD counts = rdd.reduceByKey((a, b) -> a + b); + Assert.assertEquals(1, counts.lookup(1).get(0).intValue()); + Assert.assertEquals(2, counts.lookup(2).get(0).intValue()); + Assert.assertEquals(3, counts.lookup(3).get(0).intValue()); + + Map localCounts = counts.collectAsMap(); + Assert.assertEquals(1, localCounts.get(1).intValue()); + Assert.assertEquals(2, localCounts.get(2).intValue()); + Assert.assertEquals(3, 
localCounts.get(3).intValue()); + + localCounts = rdd.reduceByKeyLocally((a, b) -> a + b); + Assert.assertEquals(1, localCounts.get(1).intValue()); + Assert.assertEquals(2, localCounts.get(2).intValue()); + Assert.assertEquals(3, localCounts.get(3).intValue()); + } + + @Test + public void map() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + JavaDoubleRDD doubles = rdd.mapToDouble(x -> 1.0 * x).cache(); + doubles.collect(); + JavaPairRDD pairs = rdd.mapToPair(x -> new Tuple2(x, x)) + .cache(); + pairs.collect(); + JavaRDD strings = rdd.map(x -> x.toString()).cache(); + strings.collect(); + } + + @Test + public void flatMap() { + JavaRDD rdd = sc.parallelize(Arrays.asList("Hello World!", + "The quick brown fox jumps over the lazy dog.")); + JavaRDD words = rdd.flatMap(x -> Arrays.asList(x.split(" "))); + + Assert.assertEquals("Hello", words.first()); + Assert.assertEquals(11, words.count()); + + JavaPairRDD pairs = rdd.flatMapToPair(s -> { + List> pairs2 = new LinkedList>(); + for (String word : s.split(" ")) pairs2.add(new Tuple2(word, word)); + return pairs2; + }); + + Assert.assertEquals(new Tuple2("Hello", "Hello"), pairs.first()); + Assert.assertEquals(11, pairs.count()); + + JavaDoubleRDD doubles = rdd.flatMapToDouble(s -> { + List lengths = new LinkedList(); + for (String word : s.split(" ")) lengths.add(word.length() * 1.0); + return lengths; + }); + + Double x = doubles.first(); + Assert.assertEquals(5.0, doubles.first().doubleValue(), 0.01); + Assert.assertEquals(11, pairs.count()); + } + + @Test + public void mapsFromPairsToPairs() { + List> pairs = Arrays.asList( + new Tuple2(1, "a"), + new Tuple2(2, "aa"), + new Tuple2(3, "aaa") + ); + JavaPairRDD pairRDD = sc.parallelizePairs(pairs); + + // Regression test for SPARK-668: + JavaPairRDD swapped = + pairRDD.flatMapToPair(x -> Collections.singletonList(x.swap())); + swapped.collect(); + + // There was never a bug here, but it's worth testing: + pairRDD.map(item -> item.swap()).collect(); + } + + @Test + public void mapPartitions() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2); + JavaRDD partitionSums = rdd.mapPartitions(iter -> { + int sum = 0; + while (iter.hasNext()) { + sum += iter.next(); + } + return Collections.singletonList(sum); + }); + + Assert.assertEquals("[3, 7]", partitionSums.collect().toString()); + } + + @Test + public void sequenceFile() { + File tempDir = Files.createTempDir(); + String outputDir = new File(tempDir, "output").getAbsolutePath(); + List> pairs = Arrays.asList( + new Tuple2(1, "a"), + new Tuple2(2, "aa"), + new Tuple2(3, "aaa") + ); + JavaPairRDD rdd = sc.parallelizePairs(pairs); + + rdd.mapToPair(pair -> + new Tuple2(new IntWritable(pair._1()), new Text(pair._2()))) + .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); + + // Try reading the output back as an object file + JavaPairRDD readRDD = sc.sequenceFile(outputDir, IntWritable.class, Text.class) + .mapToPair(pair -> new Tuple2(pair._1().get(), pair._2().toString())); + Assert.assertEquals(pairs, readRDD.collect()); + } + + @Test + public void zip() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + JavaDoubleRDD doubles = rdd.mapToDouble(x -> 1.0 * x); + JavaPairRDD zipped = rdd.zip(doubles); + zipped.count(); + } + + @Test + public void zipPartitions() { + JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 2); + JavaRDD rdd2 = sc.parallelize(Arrays.asList("1", "2", "3", "4"), 2); + FlatMapFunction2, Iterator, Integer> sizesFn = + (Iterator 
i, Iterator s) -> { + int sizeI = 0; + int sizeS = 0; + while (i.hasNext()) { + sizeI += 1; + i.next(); + } + while (s.hasNext()) { + sizeS += 1; + s.next(); + } + return Arrays.asList(sizeI, sizeS); + }; + JavaRDD sizes = rdd1.zipPartitions(rdd2, sizesFn); + Assert.assertEquals("[3, 2, 3, 2]", sizes.collect().toString()); + } + + @Test + public void accumulators() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + + final Accumulator intAccum = sc.intAccumulator(10); + rdd.foreach(x -> intAccum.add(x)); + Assert.assertEquals((Integer) 25, intAccum.value()); + + final Accumulator doubleAccum = sc.doubleAccumulator(10.0); + rdd.foreach(x -> doubleAccum.add((double) x)); + Assert.assertEquals((Double) 25.0, doubleAccum.value()); + + // Try a custom accumulator type + AccumulatorParam floatAccumulatorParam = new AccumulatorParam() { + public Float addInPlace(Float r, Float t) { + return r + t; + } + + public Float addAccumulator(Float r, Float t) { + return r + t; + } + + public Float zero(Float initialValue) { + return 0.0f; + } + }; + + final Accumulator floatAccum = sc.accumulator((Float) 10.0f, floatAccumulatorParam); + rdd.foreach(x -> floatAccum.add((float) x)); + Assert.assertEquals((Float) 25.0f, floatAccum.value()); + + // Test the setValue method + floatAccum.setValue(5.0f); + Assert.assertEquals((Float) 5.0f, floatAccum.value()); + } + + @Test + public void keyBy() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2)); + List> s = rdd.keyBy(x -> x.toString()).collect(); + Assert.assertEquals(new Tuple2("1", 1), s.get(0)); + Assert.assertEquals(new Tuple2("2", 2), s.get(1)); + } + + @Test + public void mapOnPairRDD() { + JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4)); + JavaPairRDD rdd2 = + rdd1.mapToPair(i -> new Tuple2(i, i % 2)); + JavaPairRDD rdd3 = + rdd2.mapToPair(in -> new Tuple2(in._2(), in._1())); + Assert.assertEquals(Arrays.asList( + new Tuple2(1, 1), + new Tuple2(0, 2), + new Tuple2(1, 3), + new Tuple2(0, 4)), rdd3.collect()); + } + + @Test + public void collectPartitions() { + JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3); + + JavaPairRDD rdd2 = + rdd1.mapToPair(i -> new Tuple2(i, i % 2)); + List[] parts = rdd1.collectPartitions(new int[]{0}); + Assert.assertEquals(Arrays.asList(1, 2), parts[0]); + + parts = rdd1.collectPartitions(new int[]{1, 2}); + Assert.assertEquals(Arrays.asList(3, 4), parts[0]); + Assert.assertEquals(Arrays.asList(5, 6, 7), parts[1]); + + Assert.assertEquals(Arrays.asList(new Tuple2(1, 1), + new Tuple2(2, 0)), + rdd2.collectPartitions(new int[]{0})[0]); + + parts = rdd2.collectPartitions(new int[]{1, 2}); + Assert.assertEquals(Arrays.asList(new Tuple2(3, 1), + new Tuple2(4, 0)), parts[0]); + Assert.assertEquals(Arrays.asList(new Tuple2(5, 1), + new Tuple2(6, 0), + new Tuple2(7, 1)), parts[1]); + } + + @Test + public void collectAsMapWithIntArrayValues() { + // Regression test for SPARK-1040 + JavaRDD rdd = sc.parallelize(Arrays.asList(new Integer[]{1})); + JavaPairRDD pairRDD = + rdd.mapToPair(x -> new Tuple2(x, new int[]{x})); + pairRDD.collect(); // Works fine + Map map = pairRDD.collectAsMap(); // Used to crash with ClassCastException + } +} diff --git a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java new file mode 100644 index 0000000000000..43df0dea614bc --- /dev/null +++ b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java @@ -0,0 +1,841 @@ +/* + * Licensed 
to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming; + +import java.io.Serializable; +import java.util.*; + +import scala.Tuple2; + +import com.google.common.base.Optional; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import org.junit.Assert; +import org.junit.Test; + +import org.apache.spark.HashPartitioner; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; + +/** + * Most of these tests replicate org.apache.spark.streaming.JavaAPISuite using java 8 + * lambda syntax. + */ +public class Java8APISuite extends LocalJavaStreamingContext implements Serializable { + + @Test + public void testMap() { + List> inputData = Arrays.asList( + Arrays.asList("hello", "world"), + Arrays.asList("goodnight", "moon")); + + List> expected = Arrays.asList( + Arrays.asList(5, 5), + Arrays.asList(9, 4)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream letterCount = stream.map(s -> s.length()); + JavaTestUtils.attachTestOutputStream(letterCount); + List> result = JavaTestUtils.runStreams(ssc, 2, 2); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testFilter() { + List> inputData = Arrays.asList( + Arrays.asList("giants", "dodgers"), + Arrays.asList("yankees", "red socks")); + + List> expected = Arrays.asList( + Arrays.asList("giants"), + Arrays.asList("yankees")); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream filtered = stream.filter(s -> s.contains("a")); + JavaTestUtils.attachTestOutputStream(filtered); + List> result = JavaTestUtils.runStreams(ssc, 2, 2); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testMapPartitions() { + List> inputData = Arrays.asList( + Arrays.asList("giants", "dodgers"), + Arrays.asList("yankees", "red socks")); + + List> expected = Arrays.asList( + Arrays.asList("GIANTSDODGERS"), + Arrays.asList("YANKEESRED SOCKS")); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream mapped = stream.mapPartitions(in -> { + String out = ""; + while (in.hasNext()) { + out = out + in.next().toUpperCase(); + } + return Lists.newArrayList(out); + }); + JavaTestUtils.attachTestOutputStream(mapped); + List> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testReduce() { + List> inputData = Arrays.asList( + Arrays.asList(1, 2, 3), + Arrays.asList(4, 5, 6), + Arrays.asList(7, 8, 9)); + + List> expected = Arrays.asList( + 
Arrays.asList(6), + Arrays.asList(15), + Arrays.asList(24)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream reduced = stream.reduce((x, y) -> x + y); + JavaTestUtils.attachTestOutputStream(reduced); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testReduceByWindow() { + List> inputData = Arrays.asList( + Arrays.asList(1, 2, 3), + Arrays.asList(4, 5, 6), + Arrays.asList(7, 8, 9)); + + List> expected = Arrays.asList( + Arrays.asList(6), + Arrays.asList(21), + Arrays.asList(39), + Arrays.asList(24)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream reducedWindowed = stream.reduceByWindow((x, y) -> x + y, + (x, y) -> x - y, new Duration(2000), new Duration(1000)); + JavaTestUtils.attachTestOutputStream(reducedWindowed); + List> result = JavaTestUtils.runStreams(ssc, 4, 4); + + Assert.assertEquals(expected, result); + } + + @Test + public void testTransform() { + List> inputData = Arrays.asList( + Arrays.asList(1, 2, 3), + Arrays.asList(4, 5, 6), + Arrays.asList(7, 8, 9)); + + List> expected = Arrays.asList( + Arrays.asList(3, 4, 5), + Arrays.asList(6, 7, 8), + Arrays.asList(9, 10, 11)); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream transformed = stream.transform(in -> in.map(i -> i + 2)); + + JavaTestUtils.attachTestOutputStream(transformed); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testVariousTransform() { + // tests whether all variations of transform can be called from Java + + List> inputData = Arrays.asList(Arrays.asList(1)); + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + + List>> pairInputData = + Arrays.asList(Arrays.asList(new Tuple2("x", 1))); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream( + JavaTestUtils.attachTestInputStream(ssc, pairInputData, 1)); + + JavaDStream transformed1 = stream.transform(in -> null); + JavaDStream transformed2 = stream.transform((x, time) -> null); + JavaPairDStream transformed3 = stream.transformToPair(x -> null); + JavaPairDStream transformed4 = stream.transformToPair((x, time) -> null); + JavaDStream pairTransformed1 = pairStream.transform(x -> null); + JavaDStream pairTransformed2 = pairStream.transform((x, time) -> null); + JavaPairDStream pairTransformed3 = pairStream.transformToPair(x -> null); + JavaPairDStream pairTransformed4 = + pairStream.transformToPair((x, time) -> null); + + } + + @Test + public void testTransformWith() { + List>> stringStringKVStream1 = Arrays.asList( + Arrays.asList( + new Tuple2("california", "dodgers"), + new Tuple2("new york", "yankees")), + Arrays.asList( + new Tuple2("california", "sharks"), + new Tuple2("new york", "rangers"))); + + List>> stringStringKVStream2 = Arrays.asList( + Arrays.asList( + new Tuple2("california", "giants"), + new Tuple2("new york", "mets")), + Arrays.asList( + new Tuple2("california", "ducks"), + new Tuple2("new york", "islanders"))); + + + List>>> expected = Arrays.asList( + Sets.newHashSet( + new Tuple2>("california", + new Tuple2("dodgers", "giants")), + new Tuple2>("new york", + new Tuple2("yankees", "mets"))), + Sets.newHashSet( + new Tuple2>("california", + new Tuple2("sharks", "ducks")), + new Tuple2>("new york", + new Tuple2("rangers", "islanders")))); + + JavaDStream> stream1 = JavaTestUtils.attachTestInputStream( + ssc, 
stringStringKVStream1, 1); + JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream(stream1); + + JavaDStream> stream2 = JavaTestUtils.attachTestInputStream( + ssc, stringStringKVStream2, 1); + JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); + + JavaPairDStream> joined = + pairStream1.transformWithToPair(pairStream2,(x, y, z) -> x.join(y)); + + JavaTestUtils.attachTestOutputStream(joined); + List>>> result = JavaTestUtils.runStreams(ssc, 2, 2); + List>>> unorderedResult = Lists.newArrayList(); + for (List>> res : result) { + unorderedResult.add(Sets.newHashSet(res)); + } + + Assert.assertEquals(expected, unorderedResult); + } + + + @Test + public void testVariousTransformWith() { + // tests whether all variations of transformWith can be called from Java + + List> inputData1 = Arrays.asList(Arrays.asList(1)); + List> inputData2 = Arrays.asList(Arrays.asList("x")); + JavaDStream stream1 = JavaTestUtils.attachTestInputStream(ssc, inputData1, 1); + JavaDStream stream2 = JavaTestUtils.attachTestInputStream(ssc, inputData2, 1); + + List>> pairInputData1 = + Arrays.asList(Arrays.asList(new Tuple2("x", 1))); + List>> pairInputData2 = + Arrays.asList(Arrays.asList(new Tuple2(1.0, 'x'))); + JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream( + JavaTestUtils.attachTestInputStream(ssc, pairInputData1, 1)); + JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream( + JavaTestUtils.attachTestInputStream(ssc, pairInputData2, 1)); + + JavaDStream transformed1 = stream1.transformWith(stream2, (x, y, z) -> null); + JavaDStream transformed2 = stream1.transformWith(pairStream1,(x, y, z) -> null); + + JavaPairDStream transformed3 = + stream1.transformWithToPair(stream2,(x, y, z) -> null); + + JavaPairDStream transformed4 = + stream1.transformWithToPair(pairStream1,(x, y, z) -> null); + + JavaDStream pairTransformed1 = pairStream1.transformWith(stream2,(x, y, z) -> null); + + JavaDStream pairTransformed2_ = + pairStream1.transformWith(pairStream1,(x, y, z) -> null); + + JavaPairDStream pairTransformed3 = + pairStream1.transformWithToPair(stream2,(x, y, z) -> null); + + JavaPairDStream pairTransformed4 = + pairStream1.transformWithToPair(pairStream2,(x, y, z) -> null); + } + + @Test + public void testStreamingContextTransform() { + List> stream1input = Arrays.asList( + Arrays.asList(1), + Arrays.asList(2) + ); + + List> stream2input = Arrays.asList( + Arrays.asList(3), + Arrays.asList(4) + ); + + List>> pairStream1input = Arrays.asList( + Arrays.asList(new Tuple2(1, "x")), + Arrays.asList(new Tuple2(2, "y")) + ); + + List>>> expected = Arrays.asList( + Arrays.asList(new Tuple2>(1, new Tuple2(1, "x"))), + Arrays.asList(new Tuple2>(2, new Tuple2(2, "y"))) + ); + + JavaDStream stream1 = JavaTestUtils.attachTestInputStream(ssc, stream1input, 1); + JavaDStream stream2 = JavaTestUtils.attachTestInputStream(ssc, stream2input, 1); + JavaPairDStream pairStream1 = JavaPairDStream.fromJavaDStream( + JavaTestUtils.attachTestInputStream(ssc, pairStream1input, 1)); + + List> listOfDStreams1 = Arrays.>asList(stream1, stream2); + + // This is just to test whether this transform to JavaStream compiles + JavaDStream transformed1 = ssc.transform( + listOfDStreams1, (List> listOfRDDs, Time time) -> { + assert (listOfRDDs.size() == 2); + return null; + }); + + List> listOfDStreams2 = + Arrays.>asList(stream1, stream2, pairStream1.toJavaDStream()); + + JavaPairDStream> transformed2 = ssc.transformToPair( + listOfDStreams2, (List> listOfRDDs, Time time) -> { + assert (listOfRDDs.size() 
== 3); + JavaRDD rdd1 = (JavaRDD) listOfRDDs.get(0); + JavaRDD rdd2 = (JavaRDD) listOfRDDs.get(1); + JavaRDD> rdd3 = (JavaRDD>) listOfRDDs.get(2); + JavaPairRDD prdd3 = JavaPairRDD.fromJavaRDD(rdd3); + PairFunction mapToTuple = + (Integer i) -> new Tuple2(i, i); + return rdd1.union(rdd2).mapToPair(mapToTuple).join(prdd3); + }); + JavaTestUtils.attachTestOutputStream(transformed2); + List>>> result = + JavaTestUtils.runStreams(ssc, 2, 2); + Assert.assertEquals(expected, result); + } + + @Test + public void testFlatMap() { + List> inputData = Arrays.asList( + Arrays.asList("go", "giants"), + Arrays.asList("boo", "dodgers"), + Arrays.asList("athletics")); + + List> expected = Arrays.asList( + Arrays.asList("g", "o", "g", "i", "a", "n", "t", "s"), + Arrays.asList("b", "o", "o", "d", "o", "d", "g", "e", "r", "s"), + Arrays.asList("a", "t", "h", "l", "e", "t", "i", "c", "s")); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream flatMapped = stream.flatMap(s -> Lists.newArrayList(s.split("(?!^)"))); + JavaTestUtils.attachTestOutputStream(flatMapped); + List> result = JavaTestUtils.runStreams(ssc, 3, 3); + + assertOrderInvariantEquals(expected, result); + } + + @Test + public void testPairFlatMap() { + List> inputData = Arrays.asList( + Arrays.asList("giants"), + Arrays.asList("dodgers"), + Arrays.asList("athletics")); + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2(6, "g"), + new Tuple2(6, "i"), + new Tuple2(6, "a"), + new Tuple2(6, "n"), + new Tuple2(6, "t"), + new Tuple2(6, "s")), + Arrays.asList( + new Tuple2(7, "d"), + new Tuple2(7, "o"), + new Tuple2(7, "d"), + new Tuple2(7, "g"), + new Tuple2(7, "e"), + new Tuple2(7, "r"), + new Tuple2(7, "s")), + Arrays.asList( + new Tuple2(9, "a"), + new Tuple2(9, "t"), + new Tuple2(9, "h"), + new Tuple2(9, "l"), + new Tuple2(9, "e"), + new Tuple2(9, "t"), + new Tuple2(9, "i"), + new Tuple2(9, "c"), + new Tuple2(9, "s"))); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream flatMapped = stream.flatMapToPair(s -> { + List> out = Lists.newArrayList(); + for (String letter : s.split("(?!^)")) { + out.add(new Tuple2(s.length(), letter)); + } + return out; + }); + + JavaTestUtils.attachTestOutputStream(flatMapped); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + /* + * Performs an order-invariant comparison of lists representing two RDD streams. This allows + * us to account for ordering variation within individual RDD's which occurs during windowing. 
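+   * Both the expected and the actual per-batch lists are sorted before comparison, so only the
+   * multiset of elements within each batch is checked.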
+ */ + public static > void assertOrderInvariantEquals( + List> expected, List> actual) { + for (List list : expected) { + Collections.sort(list); + } + for (List list : actual) { + Collections.sort(list); + } + Assert.assertEquals(expected, actual); + } + + @Test + public void testPairFilter() { + List> inputData = Arrays.asList( + Arrays.asList("giants", "dodgers"), + Arrays.asList("yankees", "red socks")); + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("giants", 6)), + Arrays.asList(new Tuple2("yankees", 7))); + + JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = + stream.mapToPair(x -> new Tuple2<>(x, x.length())); + JavaPairDStream filtered = pairStream.filter(x -> x._1().contains("a")); + JavaTestUtils.attachTestOutputStream(filtered); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + List>> stringStringKVStream = Arrays.asList( + Arrays.asList(new Tuple2("california", "dodgers"), + new Tuple2("california", "giants"), + new Tuple2("new york", "yankees"), + new Tuple2("new york", "mets")), + Arrays.asList(new Tuple2("california", "sharks"), + new Tuple2("california", "ducks"), + new Tuple2("new york", "rangers"), + new Tuple2("new york", "islanders"))); + + List>> stringIntKVStream = Arrays.asList( + Arrays.asList( + new Tuple2("california", 1), + new Tuple2("california", 3), + new Tuple2("new york", 4), + new Tuple2("new york", 1)), + Arrays.asList( + new Tuple2("california", 5), + new Tuple2("california", 5), + new Tuple2("new york", 3), + new Tuple2("new york", 1))); + + @Test + public void testPairMap() { // Maps pair -> pair of different type + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2(1, "california"), + new Tuple2(3, "california"), + new Tuple2(4, "new york"), + new Tuple2(1, "new york")), + Arrays.asList( + new Tuple2(5, "california"), + new Tuple2(5, "california"), + new Tuple2(3, "new york"), + new Tuple2(1, "new york"))); + + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + JavaPairDStream reversed = pairStream.mapToPair(x -> x.swap()); + JavaTestUtils.attachTestOutputStream(reversed); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairMapPartitions() { // Maps pair -> pair of different type + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2(1, "california"), + new Tuple2(3, "california"), + new Tuple2(4, "new york"), + new Tuple2(1, "new york")), + Arrays.asList( + new Tuple2(5, "california"), + new Tuple2(5, "california"), + new Tuple2(3, "new york"), + new Tuple2(1, "new york"))); + + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + JavaPairDStream reversed = pairStream.mapPartitionsToPair(in -> { + LinkedList> out = new LinkedList>(); + while (in.hasNext()) { + Tuple2 next = in.next(); + out.add(next.swap()); + } + return out; + }); + + JavaTestUtils.attachTestOutputStream(reversed); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairMap2() { // Maps pair -> single + List>> inputData = stringIntKVStream; + + List> expected = Arrays.asList( + 
Arrays.asList(1, 3, 4, 1), + Arrays.asList(5, 5, 3, 1)); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + JavaDStream reversed = pairStream.map(in -> in._2()); + JavaTestUtils.attachTestOutputStream(reversed); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair + List>> inputData = Arrays.asList( + Arrays.asList( + new Tuple2("hi", 1), + new Tuple2("ho", 2)), + Arrays.asList( + new Tuple2("hi", 1), + new Tuple2("ho", 2))); + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2(1, "h"), + new Tuple2(1, "i"), + new Tuple2(2, "h"), + new Tuple2(2, "o")), + Arrays.asList( + new Tuple2(1, "h"), + new Tuple2(1, "i"), + new Tuple2(2, "h"), + new Tuple2(2, "o"))); + + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + JavaPairDStream flatMapped = pairStream.flatMapToPair(in -> { + List> out = new LinkedList>(); + for (Character s : in._1().toCharArray()) { + out.add(new Tuple2(in._2(), s.toString())); + } + return out; + }); + + JavaTestUtils.attachTestOutputStream(flatMapped); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairReduceByKey() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2("california", 4), + new Tuple2("new york", 5)), + Arrays.asList( + new Tuple2("california", 10), + new Tuple2("new york", 4))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream reduced = pairStream.reduceByKey((x, y) -> x + y); + + JavaTestUtils.attachTestOutputStream(reduced); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testCombineByKey() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2("california", 4), + new Tuple2("new york", 5)), + Arrays.asList( + new Tuple2("california", 10), + new Tuple2("new york", 4))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream combined = pairStream.combineByKey(i -> i, + (x, y) -> x + y, (x, y) -> x + y, new HashPartitioner(2)); + + JavaTestUtils.attachTestOutputStream(combined); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testReduceByKeyAndWindow() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", 4), + new Tuple2("new york", 5)), + Arrays.asList(new Tuple2("california", 14), + new Tuple2("new york", 9)), + Arrays.asList(new Tuple2("california", 10), + new Tuple2("new york", 4))); + + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream reduceWindowed = + pairStream.reduceByKeyAndWindow((x, y) -> x + y, new Duration(2000), new Duration(1000)); + JavaTestUtils.attachTestOutputStream(reduceWindowed); + List>> result = 
JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testUpdateStateByKey() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", 4), + new Tuple2("new york", 5)), + Arrays.asList(new Tuple2("california", 14), + new Tuple2("new york", 9)), + Arrays.asList(new Tuple2("california", 14), + new Tuple2("new york", 9))); + + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream updated = pairStream.updateStateByKey((values, state) -> { + int out = 0; + if (state.isPresent()) { + out = out + state.get(); + } + for (Integer v : values) { + out = out + v; + } + return Optional.of(out); + }); + + JavaTestUtils.attachTestOutputStream(updated); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testReduceByKeyAndWindowWithInverse() { + List>> inputData = stringIntKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", 4), + new Tuple2("new york", 5)), + Arrays.asList(new Tuple2("california", 14), + new Tuple2("new york", 9)), + Arrays.asList(new Tuple2("california", 10), + new Tuple2("new york", 4))); + + JavaDStream> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream reduceWindowed = + pairStream.reduceByKeyAndWindow((x, y) -> x + y, (x, y) -> x - y, new Duration(2000), + new Duration(1000)); + JavaTestUtils.attachTestOutputStream(reduceWindowed); + List>> result = JavaTestUtils.runStreams(ssc, 3, 3); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairTransform() { + List>> inputData = Arrays.asList( + Arrays.asList( + new Tuple2(3, 5), + new Tuple2(1, 5), + new Tuple2(4, 5), + new Tuple2(2, 5)), + Arrays.asList( + new Tuple2(2, 5), + new Tuple2(3, 5), + new Tuple2(4, 5), + new Tuple2(1, 5))); + + List>> expected = Arrays.asList( + Arrays.asList( + new Tuple2(1, 5), + new Tuple2(2, 5), + new Tuple2(3, 5), + new Tuple2(4, 5)), + Arrays.asList( + new Tuple2(1, 5), + new Tuple2(2, 5), + new Tuple2(3, 5), + new Tuple2(4, 5))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream sorted = pairStream.transformToPair(in -> in.sortByKey()); + + JavaTestUtils.attachTestOutputStream(sorted); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairToNormalRDDTransform() { + List>> inputData = Arrays.asList( + Arrays.asList( + new Tuple2(3, 5), + new Tuple2(1, 5), + new Tuple2(4, 5), + new Tuple2(2, 5)), + Arrays.asList( + new Tuple2(2, 5), + new Tuple2(3, 5), + new Tuple2(4, 5), + new Tuple2(1, 5))); + + List> expected = Arrays.asList( + Arrays.asList(3, 1, 4, 2), + Arrays.asList(2, 3, 4, 1)); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + JavaDStream firstParts = pairStream.transform(in -> in.map(x -> x._1())); + JavaTestUtils.attachTestOutputStream(firstParts); + List> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testMapValues() { + List>> inputData = 
stringStringKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", "DODGERS"), + new Tuple2("california", "GIANTS"), + new Tuple2("new york", "YANKEES"), + new Tuple2("new york", "METS")), + Arrays.asList(new Tuple2("california", "SHARKS"), + new Tuple2("california", "DUCKS"), + new Tuple2("new york", "RANGERS"), + new Tuple2("new york", "ISLANDERS"))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream mapped = pairStream.mapValues(s -> s.toUpperCase()); + JavaTestUtils.attachTestOutputStream(mapped); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testFlatMapValues() { + List>> inputData = stringStringKVStream; + + List>> expected = Arrays.asList( + Arrays.asList(new Tuple2("california", "dodgers1"), + new Tuple2("california", "dodgers2"), + new Tuple2("california", "giants1"), + new Tuple2("california", "giants2"), + new Tuple2("new york", "yankees1"), + new Tuple2("new york", "yankees2"), + new Tuple2("new york", "mets1"), + new Tuple2("new york", "mets2")), + Arrays.asList(new Tuple2("california", "sharks1"), + new Tuple2("california", "sharks2"), + new Tuple2("california", "ducks1"), + new Tuple2("california", "ducks2"), + new Tuple2("new york", "rangers1"), + new Tuple2("new york", "rangers2"), + new Tuple2("new york", "islanders1"), + new Tuple2("new york", "islanders2"))); + + JavaDStream> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); + + + JavaPairDStream flatMapped = pairStream.flatMapValues(in -> { + List out = new ArrayList(); + out.add(in + "1"); + out.add(in + "2"); + return out; + }); + JavaTestUtils.attachTestOutputStream(flatMapped); + List>> result = JavaTestUtils.runStreams(ssc, 2, 2); + Assert.assertEquals(expected, result); + } + +} diff --git a/extras/java8-tests/src/test/resources/log4j.properties b/extras/java8-tests/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..180beaa8cc5a7 --- /dev/null +++ b/extras/java8-tests/src/test/resources/log4j.properties @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=false +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.eclipse.jetty=WARN +org.eclipse.jetty.LEVEL=WARN diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml new file mode 100644 index 0000000000000..11ac827ed54a0 --- /dev/null +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -0,0 +1,45 @@ + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.0.0-SNAPSHOT + ../../pom.xml + + + + org.apache.spark + spark-ganglia-lgpl_2.10 + jar + Spark Ganglia Integration + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + + + + com.codahale.metrics + metrics-ganglia + + + diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala b/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala similarity index 96% rename from core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala rename to extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala index 410ca0704b5c4..cd37317da77de 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala +++ b/extras/spark-ganglia-lgpl/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala @@ -24,9 +24,11 @@ import com.codahale.metrics.MetricRegistry import com.codahale.metrics.ganglia.GangliaReporter import info.ganglia.gmetric4j.gmetric.GMetric +import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem -class GangliaSink(val property: Properties, val registry: MetricRegistry) extends Sink { +class GangliaSink(val property: Properties, val registry: MetricRegistry, + securityMgr: SecurityManager) extends Sink { val GANGLIA_KEY_PERIOD = "period" val GANGLIA_DEFAULT_PERIOD = 10 diff --git a/graphx/pom.xml b/graphx/pom.xml index 4823ed1d4eaec..894a7c2641e39 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.0.0-incubating-SNAPSHOT + 1.0.0-SNAPSHOT ../pom.xml @@ -31,6 +31,20 @@ Spark Project GraphX http://spark-project.org/ + + + + yarn-alpha + + + org.apache.avro + avro + + + + + org.apache.spark diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index 0fc1e4df6813c..377d9d6bd5e72 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -18,11 +18,11 @@ package org.apache.spark.graphx import scala.reflect.ClassTag - import org.apache.spark.SparkContext._ import org.apache.spark.SparkException import org.apache.spark.graphx.lib._ import org.apache.spark.rdd.RDD +import scala.util.Random /** * Contains additional functionality for [[Graph]]. All operations are expressed in terms of the @@ -137,6 +137,42 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali } } // end of collectNeighbor + /** + * Returns an RDD that contains for each vertex v its local edges, + * i.e., the edges that are incident on v, in the user-specified direction. 
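+   * For a directed edge (u, v), the edge is attributed to u under EdgeDirection.Out, to v under
+   * EdgeDirection.In, and to both u and v under EdgeDirection.Either.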
+ * Warning: note that singleton vertices, those with no edges in the given + * direction, will not be part of the return value. + * + * @note This function could be highly inefficient on power-law + * graphs where high degree vertices may force a large amount of + * information to be collected to a single location. + * + * @param edgeDirection the direction along which to collect + * the local edges of vertices + * + * @return the local edges for each vertex + */ + def collectEdges(edgeDirection: EdgeDirection): VertexRDD[Array[Edge[ED]]] = { + edgeDirection match { + case EdgeDirection.Either => + graph.mapReduceTriplets[Array[Edge[ED]]]( + edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr))), + (edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))), + (a, b) => a ++ b) + case EdgeDirection.In => + graph.mapReduceTriplets[Array[Edge[ED]]]( + edge => Iterator((edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))), + (a, b) => a ++ b) + case EdgeDirection.Out => + graph.mapReduceTriplets[Array[Edge[ED]]]( + edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))), + (a, b) => a ++ b) + case EdgeDirection.Both => + throw new SparkException("collectEdges does not support EdgeDirection.Both. Use " + + "EdgeDirection.Either instead.") + } + } + /** * Join the vertices with an RDD and then apply a function from the * the vertex and RDD entry to a new vertex value. The input table @@ -209,6 +245,27 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali graph.mask(preprocess(graph).subgraph(epred, vpred)) } + /** + * Picks a random vertex from the graph and returns its ID. + */ + def pickRandomVertex(): VertexId = { + val probability = 50.0 / graph.numVertices + var found = false + var retVal: VertexId = null.asInstanceOf[VertexId] + while (!found) { + val selectedVertices = graph.vertices.flatMap { vidVvals => + if (Random.nextDouble() < probability) { Some(vidVvals._1) } + else { None } + } + if (selectedVertices.count > 1) { + found = true + val collectedVertices = selectedVertices.collect() + retVal = collectedVertices(Random.nextInt(collectedVertices.size)) + } + } + retVal + } + /** * Execute a Pregel-like iterative vertex-parallel abstraction. The * user-defined vertex-program `vprog` is executed in parallel on diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 1d029bf009e8c..5e9be18990ba3 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -391,6 +391,6 @@ object GraphImpl { // TODO: Consider doing map side distinct before shuffle.
new ShuffledRDD[VertexId, Int, (VertexId, Int)]( edges.collectVertexIds.map(vid => (vid, 0)), partitioner) - .setSerializer(classOf[VertexIdMsgSerializer].getName) + .setSerializer(new VertexIdMsgSerializer) } } // end of object GraphImpl diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala index e9ee09c3614c1..fe6fe76defdc5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala @@ -65,11 +65,11 @@ class VertexBroadcastMsgRDDFunctions[T: ClassTag](self: RDD[VertexBroadcastMsg[T // Set a custom serializer if the data is of int or double type. if (classTag[T] == ClassTag.Int) { - rdd.setSerializer(classOf[IntVertexBroadcastMsgSerializer].getName) + rdd.setSerializer(new IntVertexBroadcastMsgSerializer) } else if (classTag[T] == ClassTag.Long) { - rdd.setSerializer(classOf[LongVertexBroadcastMsgSerializer].getName) + rdd.setSerializer(new LongVertexBroadcastMsgSerializer) } else if (classTag[T] == ClassTag.Double) { - rdd.setSerializer(classOf[DoubleVertexBroadcastMsgSerializer].getName) + rdd.setSerializer(new DoubleVertexBroadcastMsgSerializer) } rdd } @@ -104,11 +104,11 @@ object MsgRDDFunctions { // Set a custom serializer if the data is of int or double type. if (classTag[T] == ClassTag.Int) { - rdd.setSerializer(classOf[IntAggMsgSerializer].getName) + rdd.setSerializer(new IntAggMsgSerializer) } else if (classTag[T] == ClassTag.Long) { - rdd.setSerializer(classOf[LongAggMsgSerializer].getName) + rdd.setSerializer(new LongAggMsgSerializer) } else if (classTag[T] == ClassTag.Double) { - rdd.setSerializer(classOf[DoubleAggMsgSerializer].getName) + rdd.setSerializer(new DoubleAggMsgSerializer) } rdd } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala index c74d487e206db..34a145e01818f 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala @@ -25,7 +25,7 @@ import org.apache.spark.graphx._ import org.apache.spark.serializer._ private[graphx] -class VertexIdMsgSerializer(conf: SparkConf) extends Serializer { +class VertexIdMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { @@ -46,7 +46,7 @@ class VertexIdMsgSerializer(conf: SparkConf) extends Serializer { /** A special shuffle serializer for VertexBroadcastMessage[Int]. */ private[graphx] -class IntVertexBroadcastMsgSerializer(conf: SparkConf) extends Serializer { +class IntVertexBroadcastMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { @@ -70,7 +70,7 @@ class IntVertexBroadcastMsgSerializer(conf: SparkConf) extends Serializer { /** A special shuffle serializer for VertexBroadcastMessage[Long]. 
*/ private[graphx] -class LongVertexBroadcastMsgSerializer(conf: SparkConf) extends Serializer { +class LongVertexBroadcastMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { @@ -94,7 +94,7 @@ class LongVertexBroadcastMsgSerializer(conf: SparkConf) extends Serializer { /** A special shuffle serializer for VertexBroadcastMessage[Double]. */ private[graphx] -class DoubleVertexBroadcastMsgSerializer(conf: SparkConf) extends Serializer { +class DoubleVertexBroadcastMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { @@ -118,7 +118,7 @@ class DoubleVertexBroadcastMsgSerializer(conf: SparkConf) extends Serializer { /** A special shuffle serializer for AggregationMessage[Int]. */ private[graphx] -class IntAggMsgSerializer(conf: SparkConf) extends Serializer { +class IntAggMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { @@ -142,7 +142,7 @@ class IntAggMsgSerializer(conf: SparkConf) extends Serializer { /** A special shuffle serializer for AggregationMessage[Long]. */ private[graphx] -class LongAggMsgSerializer(conf: SparkConf) extends Serializer { +class LongAggMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { @@ -166,7 +166,7 @@ class LongAggMsgSerializer(conf: SparkConf) extends Serializer { /** A special shuffle serializer for AggregationMessage[Double]. 
*/ private[graphx] -class DoubleAggMsgSerializer(conf: SparkConf) extends Serializer { +class DoubleAggMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index d1528e2f07cf2..014a7335f85cc 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -23,8 +23,8 @@ import scala.collection.mutable.HashSet import org.apache.spark.util.Utils -import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor} -import org.objectweb.asm.Opcodes._ +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor} +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ /** diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala index bc2ad5677f806..6386306c048fc 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala @@ -42,21 +42,20 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext { test("collectNeighborIds") { withSpark { sc => - val chain = (0 until 100).map(x => (x, (x+1)%100) ) - val rawEdges = sc.parallelize(chain, 3).map { case (s,d) => (s.toLong, d.toLong) } - val graph = Graph.fromEdgeTuples(rawEdges, 1.0).cache() + val graph = getCycleGraph(sc, 100) val nbrs = graph.collectNeighborIds(EdgeDirection.Either).cache() - assert(nbrs.count === chain.size) + assert(nbrs.count === 100) assert(graph.numVertices === nbrs.count) nbrs.collect.foreach { case (vid, nbrs) => assert(nbrs.size === 2) } - nbrs.collect.foreach { case (vid, nbrs) => - val s = nbrs.toSet - assert(s.contains((vid + 1) % 100)) - assert(s.contains(if (vid > 0) vid - 1 else 99 )) + nbrs.collect.foreach { + case (vid, nbrs) => + val s = nbrs.toSet + assert(s.contains((vid + 1) % 100)) + assert(s.contains(if (vid > 0) vid - 1 else 99)) } } } - + test ("filter") { withSpark { sc => val n = 5 @@ -80,4 +79,121 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext { } } + test("collectEdgesCycleDirectionOut") { + withSpark { sc => + val graph = getCycleGraph(sc, 100) + val edges = graph.collectEdges(EdgeDirection.Out).cache() + assert(edges.count == 100) + edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) } + edges.collect.foreach { + case (vid, edges) => + val s = edges.toSet + val edgeDstIds = s.map(e => e.dstId) + assert(edgeDstIds.contains((vid + 1) % 100)) + } + } + } + + test("collectEdgesCycleDirectionIn") { + withSpark { sc => + val graph = getCycleGraph(sc, 100) + val edges = graph.collectEdges(EdgeDirection.In).cache() + assert(edges.count == 100) + edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) } + edges.collect.foreach { + case (vid, edges) => + val s = edges.toSet + val edgeSrcIds = s.map(e => e.srcId) + assert(edgeSrcIds.contains(if (vid > 0) vid - 1 else 99)) + } + } + } + + test("collectEdgesCycleDirectionEither") { + withSpark { sc => + val graph = getCycleGraph(sc, 100) + val edges = graph.collectEdges(EdgeDirection.Either).cache() + assert(edges.count == 100) + edges.collect.foreach { case 
(vid, edges) => assert(edges.size == 2) } + edges.collect.foreach { + case (vid, edges) => + val s = edges.toSet + val edgeIds = s.map(e => if (vid != e.srcId) e.srcId else e.dstId) + assert(edgeIds.contains((vid + 1) % 100)) + assert(edgeIds.contains(if (vid > 0) vid - 1 else 99)) + } + } + } + + test("collectEdgesChainDirectionOut") { + withSpark { sc => + val graph = getChainGraph(sc, 50) + val edges = graph.collectEdges(EdgeDirection.Out).cache() + assert(edges.count == 49) + edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) } + edges.collect.foreach { + case (vid, edges) => + val s = edges.toSet + val edgeDstIds = s.map(e => e.dstId) + assert(edgeDstIds.contains(vid + 1)) + } + } + } + + test("collectEdgesChainDirectionIn") { + withSpark { sc => + val graph = getChainGraph(sc, 50) + val edges = graph.collectEdges(EdgeDirection.In).cache() + // We expect only 49 because collectEdges does not return vertices that do + // not have any edges in the specified direction. + assert(edges.count == 49) + edges.collect.foreach { case (vid, edges) => assert(edges.size == 1) } + edges.collect.foreach { + case (vid, edges) => + val s = edges.toSet + val edgeSrcIds = s.map(e => e.srcId) + assert(edgeSrcIds.contains(vid - 1)) + } + } + } + + test("collectEdgesChainDirectionEither") { + withSpark { sc => + val graph = getChainGraph(sc, 50) + val edges = graph.collectEdges(EdgeDirection.Either).cache() + // Every vertex of the 50-vertex chain has at least one incident edge when either + // direction is considered, so all 50 vertices are returned. + assert(edges.count === 50) + edges.collect.foreach { + case (vid, edges) => if (vid > 0 && vid < 49) assert(edges.size == 2) + else assert(edges.size == 1) + } + edges.collect.foreach { + case (vid, edges) => + val s = edges.toSet + val edgeIds = s.map(e => if (vid != e.srcId) e.srcId else e.dstId) + if (vid == 0) { assert(edgeIds.contains(1)) } + else if (vid == 49) { assert(edgeIds.contains(48)) } + else { + assert(edgeIds.contains(vid + 1)) + assert(edgeIds.contains(vid - 1)) + } + } + } + } + + private def getCycleGraph(sc: SparkContext, numVertices: Int): Graph[Double, Int] = { + val cycle = (0 until numVertices).map(x => (x, (x + 1) % numVertices)) + getGraphFromSeq(sc, cycle) + } + + private def getChainGraph(sc: SparkContext, numVertices: Int): Graph[Double, Int] = { + val chain = (0 until numVertices - 1).map(x => (x, (x + 1))) + getGraphFromSeq(sc, chain) + } + + private def getGraphFromSeq(sc: SparkContext, seq: IndexedSeq[(Int, Int)]): Graph[Double, Int] = { + val rawEdges = sc.parallelize(seq, 3).map { case (s, d) => (s.toLong, d.toLong) } + Graph.fromEdgeTuples(rawEdges, 1.0).cache() + } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala index e5a582b47ba05..73438d9535962 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala @@ -32,15 +32,14 @@ import org.apache.spark.serializer.SerializationStream class SerializerSuite extends FunSuite with LocalSparkContext { test("IntVertexBroadcastMsgSerializer") { - val conf = new SparkConf(false) val outMsg = new VertexBroadcastMsg[Int](3, 4, 5) val bout = new ByteArrayOutputStream - val outStrm = new IntVertexBroadcastMsgSerializer(conf).newInstance().serializeStream(bout) + val outStrm = new IntVertexBroadcastMsgSerializer().newInstance().serializeStream(bout) outStrm.writeObject(outMsg)
outStrm.writeObject(outMsg) bout.flush() val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new IntVertexBroadcastMsgSerializer(conf).newInstance().deserializeStream(bin) + val inStrm = new IntVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin) val inMsg1: VertexBroadcastMsg[Int] = inStrm.readObject() val inMsg2: VertexBroadcastMsg[Int] = inStrm.readObject() assert(outMsg.vid === inMsg1.vid) @@ -54,15 +53,14 @@ class SerializerSuite extends FunSuite with LocalSparkContext { } test("LongVertexBroadcastMsgSerializer") { - val conf = new SparkConf(false) val outMsg = new VertexBroadcastMsg[Long](3, 4, 5) val bout = new ByteArrayOutputStream - val outStrm = new LongVertexBroadcastMsgSerializer(conf).newInstance().serializeStream(bout) + val outStrm = new LongVertexBroadcastMsgSerializer().newInstance().serializeStream(bout) outStrm.writeObject(outMsg) outStrm.writeObject(outMsg) bout.flush() val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new LongVertexBroadcastMsgSerializer(conf).newInstance().deserializeStream(bin) + val inStrm = new LongVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin) val inMsg1: VertexBroadcastMsg[Long] = inStrm.readObject() val inMsg2: VertexBroadcastMsg[Long] = inStrm.readObject() assert(outMsg.vid === inMsg1.vid) @@ -76,15 +74,14 @@ class SerializerSuite extends FunSuite with LocalSparkContext { } test("DoubleVertexBroadcastMsgSerializer") { - val conf = new SparkConf(false) val outMsg = new VertexBroadcastMsg[Double](3, 4, 5.0) val bout = new ByteArrayOutputStream - val outStrm = new DoubleVertexBroadcastMsgSerializer(conf).newInstance().serializeStream(bout) + val outStrm = new DoubleVertexBroadcastMsgSerializer().newInstance().serializeStream(bout) outStrm.writeObject(outMsg) outStrm.writeObject(outMsg) bout.flush() val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new DoubleVertexBroadcastMsgSerializer(conf).newInstance().deserializeStream(bin) + val inStrm = new DoubleVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin) val inMsg1: VertexBroadcastMsg[Double] = inStrm.readObject() val inMsg2: VertexBroadcastMsg[Double] = inStrm.readObject() assert(outMsg.vid === inMsg1.vid) @@ -98,15 +95,14 @@ class SerializerSuite extends FunSuite with LocalSparkContext { } test("IntAggMsgSerializer") { - val conf = new SparkConf(false) val outMsg = (4: VertexId, 5) val bout = new ByteArrayOutputStream - val outStrm = new IntAggMsgSerializer(conf).newInstance().serializeStream(bout) + val outStrm = new IntAggMsgSerializer().newInstance().serializeStream(bout) outStrm.writeObject(outMsg) outStrm.writeObject(outMsg) bout.flush() val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new IntAggMsgSerializer(conf).newInstance().deserializeStream(bin) + val inStrm = new IntAggMsgSerializer().newInstance().deserializeStream(bin) val inMsg1: (VertexId, Int) = inStrm.readObject() val inMsg2: (VertexId, Int) = inStrm.readObject() assert(outMsg === inMsg1) @@ -118,15 +114,14 @@ class SerializerSuite extends FunSuite with LocalSparkContext { } test("LongAggMsgSerializer") { - val conf = new SparkConf(false) val outMsg = (4: VertexId, 1L << 32) val bout = new ByteArrayOutputStream - val outStrm = new LongAggMsgSerializer(conf).newInstance().serializeStream(bout) + val outStrm = new LongAggMsgSerializer().newInstance().serializeStream(bout) outStrm.writeObject(outMsg) outStrm.writeObject(outMsg) bout.flush() val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new 
LongAggMsgSerializer(conf).newInstance().deserializeStream(bin) + val inStrm = new LongAggMsgSerializer().newInstance().deserializeStream(bin) val inMsg1: (VertexId, Long) = inStrm.readObject() val inMsg2: (VertexId, Long) = inStrm.readObject() assert(outMsg === inMsg1) @@ -138,15 +133,14 @@ class SerializerSuite extends FunSuite with LocalSparkContext { } test("DoubleAggMsgSerializer") { - val conf = new SparkConf(false) val outMsg = (4: VertexId, 5.0) val bout = new ByteArrayOutputStream - val outStrm = new DoubleAggMsgSerializer(conf).newInstance().serializeStream(bout) + val outStrm = new DoubleAggMsgSerializer().newInstance().serializeStream(bout) outStrm.writeObject(outMsg) outStrm.writeObject(outMsg) bout.flush() val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new DoubleAggMsgSerializer(conf).newInstance().deserializeStream(bin) + val inStrm = new DoubleAggMsgSerializer().newInstance().deserializeStream(bin) val inMsg1: (VertexId, Double) = inStrm.readObject() val inMsg2: (VertexId, Double) = inStrm.readObject() assert(outMsg === inMsg1) diff --git a/make-distribution.sh b/make-distribution.sh index e6b5956d1e7e2..6bc6819d8da92 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -58,6 +58,7 @@ echo "Version is ${VERSION}" # Initialize defaults SPARK_HADOOP_VERSION=1.0.4 SPARK_YARN=false +SPARK_TACHYON=false MAKE_TGZ=false # Parse arguments @@ -70,6 +71,9 @@ while (( "$#" )); do --with-yarn) SPARK_YARN=true ;; + --with-tachyon) + SPARK_TACHYON=true + ;; --tgz) MAKE_TGZ=true ;; @@ -90,6 +94,12 @@ else echo "YARN disabled" fi +if [ "$SPARK_TACHYON" == "true" ]; then + echo "Tachyon Enabled" +else + echo "Tachyon Disabled" +fi + # Build fat JAR export SPARK_HADOOP_VERSION export SPARK_YARN @@ -113,6 +123,28 @@ cp -r "$FWDIR/python" "$DISTDIR" cp -r "$FWDIR/sbin" "$DISTDIR" +# Download and copy in tachyon, if requested +if [ "$SPARK_TACHYON" == "true" ]; then + TACHYON_VERSION="0.4.1" + TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/tachyon-${TACHYON_VERSION}-bin.tar.gz" + + TMPD=`mktemp -d` + + pushd $TMPD > /dev/null + echo "Fetchting tachyon tgz" + wget "$TACHYON_URL" + + tar xf "tachyon-${TACHYON_VERSION}-bin.tar.gz" + cp "tachyon-${TACHYON_VERSION}/target/tachyon-${TACHYON_VERSION}-jar-with-dependencies.jar" "$DISTDIR/jars" + mkdir -p "$DISTDIR/tachyon/src/main/java/tachyon/web" + cp -r "tachyon-${TACHYON_VERSION}"/{bin,conf,libexec} "$DISTDIR/tachyon" + cp -r "tachyon-${TACHYON_VERSION}"/src/main/java/tachyon/web/resources "$DISTDIR/tachyon/src/main/java/tachyon/web" + sed -i "s|export TACHYON_JAR=\$TACHYON_HOME/target/\(.*\)|# This is set for spark's make-distribution\n export TACHYON_JAR=\$TACHYON_HOME/../../jars/\1|" "$DISTDIR/tachyon/libexec/tachyon-config.sh" + + popd > /dev/null + rm -rf $TMPD +fi + if [ "$MAKE_TGZ" == "true" ]; then TARDIR="$FWDIR/spark-$VERSION" cp -r "$DISTDIR" "$TARDIR" diff --git a/mllib/pom.xml b/mllib/pom.xml index 9a61d7c3e46c0..9b65cb4b4ce3f 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.0.0-incubating-SNAPSHOT + 1.0.0-SNAPSHOT ../pom.xml @@ -29,7 +29,21 @@ spark-mllib_2.10 jar Spark Project ML Library - http://spark.incubator.apache.org/ + http://spark.apache.org/ + + + + + yarn-alpha + + + org.apache.avro + avro + + + + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SVD.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SVD.scala index 8803c4c1a07be..e4a26eeb07c60 100644 --- 
a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SVD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SVD.scala @@ -109,7 +109,7 @@ object SVD { // Construct jblas A^T A locally val ata = DoubleMatrix.zeros(n, n) - for (entry <- emits.toArray) { + for (entry <- emits.collect()) { ata.put(entry._1._1, entry._1._2, entry._2) } @@ -178,7 +178,7 @@ object SVD { val s = decomposed.S.data val v = decomposed.V.data - println("Computed " + s.toArray.length + " singular values and vectors") + println("Computed " + s.collect().length + " singular values and vectors") u.saveAsTextFile(output_u) s.saveAsTextFile(output_s) v.saveAsTextFile(output_v) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 8e87b98bac061..b967b22e818d3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -149,7 +149,13 @@ object GradientDescent extends Logging { // Initialize weights as a column vector var weights = new DoubleMatrix(initialWeights.length, 1, initialWeights:_*) - var regVal = 0.0 + + /** + * For the first iteration, regVal is initialized by calling the updater with a zero + * gradient: for the L2 updater it is the regularized sum of the squared weights, and the + * L1 updater follows the same logic with the L1 norm. + */ + var regVal = updater.compute( + weights, new DoubleMatrix(initialWeights.length, 1), 0, 1, regParam)._2 for (i <- 1 to numIterations) { // Sample a subset (fraction miniBatchFraction) of the total data diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala index 889a03e3e61d2..bf8f731459e99 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala @@ -111,6 +111,8 @@ class SquaredL2Updater extends Updater { val step = gradient.mul(thisIterStepSize) // add up both updates from the gradient of the loss (= step) as well as // the gradient of the regularizer (= regParam * weightsOld) + // w' = w - thisIterStepSize * (gradient + regParam * w) + // w' = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient val newWeights = weightsOld.mul(1.0 - thisIterStepSize * regParam).sub(step) (newWeights, 0.5 * pow(newWeights.norm2, 2.0) * regParam) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 8958040e36640..0cc9f48769f83 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -89,10 +89,15 @@ case class Rating(val user: Int, val product: Int, val rating: Double) * indicated user * preferences rather than explicit ratings given to items.
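+ * In the implicit-preference formulation each observed entry is treated as a binary preference
+ * weighted by a confidence of 1 + alpha * |rating| (see the update step below).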
*/ -class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var lambda: Double, - var implicitPrefs: Boolean, var alpha: Double) - extends Serializable with Logging -{ +class ALS private ( + var numBlocks: Int, + var rank: Int, + var iterations: Int, + var lambda: Double, + var implicitPrefs: Boolean, + var alpha: Double, + var seed: Long = System.nanoTime() + ) extends Serializable with Logging { def this() = this(-1, 10, 10, 0.01, false, 1.0) /** @@ -132,13 +137,21 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l this } + /** Sets a random seed to have deterministic results. */ + def setSeed(seed: Long): ALS = { + this.seed = seed + this + } + /** * Run ALS with the configured parameters on an input RDD of (user, product, rating) triples. * Returns a MatrixFactorizationModel with feature vectors for each user and product. */ def run(ratings: RDD[Rating]): MatrixFactorizationModel = { + val sc = ratings.context + val numBlocks = if (this.numBlocks == -1) { - math.max(ratings.context.defaultParallelism, ratings.partitions.size / 2) + math.max(sc.defaultParallelism, ratings.partitions.size / 2) } else { this.numBlocks } @@ -155,7 +168,7 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l // Initialize user and product factors randomly, but use a deterministic seed for each // partition so that fault recovery works - val seedGen = new Random() + val seedGen = new Random(seed) val seed1 = seedGen.nextInt() val seed2 = seedGen.nextInt() // Hash an integer to propagate random bits at all positions, similar to java.util.HashTable @@ -176,21 +189,41 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l } } - for (iter <- 1 to iterations) { - // perform ALS update - logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations)) - // YtY / XtX is an Option[DoubleMatrix] and is only required for the implicit feedback model - val YtY = computeYtY(users) - val YtYb = ratings.context.broadcast(YtY) - products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda, - alpha, YtYb) - logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations)) - val XtX = computeYtY(products) - val XtXb = ratings.context.broadcast(XtX) - users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda, - alpha, XtXb) + if (implicitPrefs) { + for (iter <- 1 to iterations) { + // perform ALS update + logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations)) + // Persist users because it will be called twice. 
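+        // computeYtY aggregates the factors into a small rank x rank matrix on the driver, which
+        // is then broadcast to the executors; the factor RDD from the previous iteration is
+        // unpersisted once it is no longer needed.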
+ users.persist() + val YtY = Some(sc.broadcast(computeYtY(users))) + val previousProducts = products + products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda, + alpha, YtY) + previousProducts.unpersist() + logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations)) + products.persist() + val XtX = Some(sc.broadcast(computeYtY(products))) + val previousUsers = users + users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda, + alpha, XtX) + previousUsers.unpersist() + } + } else { + for (iter <- 1 to iterations) { + // perform ALS update + logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations)) + products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda, + alpha, YtY = None) + logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations)) + users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda, + alpha, YtY = None) + } } + // The last `products` will be used twice. One to generate the last `users` and the other to + // generate `productsOut`. So we cache it for better performance. + products.persist() + // Flatten and cache the two final RDDs to un-block them val usersOut = unblockFactors(users, userOutLinks) val productsOut = unblockFactors(products, productOutLinks) @@ -198,37 +231,70 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l usersOut.persist() productsOut.persist() + // Materialize usersOut and productsOut. + usersOut.count() + productsOut.count() + + products.unpersist() + + // Clean up. + userInLinks.unpersist() + userOutLinks.unpersist() + productInLinks.unpersist() + productOutLinks.unpersist() + new MatrixFactorizationModel(rank, usersOut, productsOut) } /** * Computes the (`rank x rank`) matrix `YtY`, where `Y` is the (`nui x rank`) matrix of factors - * for each user (or product), in a distributed fashion. Here `reduceByKeyLocally` is used as - * the driver program requires `YtY` to broadcast it to the slaves + * for each user (or product), in a distributed fashion. + * * @param factors the (block-distributed) user or product factor vectors - * @return Option[YtY] - whose value is only used in the implicit preference model + * @return YtY - whose value is only used in the implicit preference model */ - def computeYtY(factors: RDD[(Int, Array[Array[Double]])]) = { - if (implicitPrefs) { - Option( - factors.flatMapValues { case factorArray => - factorArray.view.map { vector => - val x = new DoubleMatrix(vector) - x.mmul(x.transpose()) - } - }.reduceByKeyLocally((a, b) => a.addi(b)) - .values - .reduce((a, b) => a.addi(b)) - ) - } else { - None + private def computeYtY(factors: RDD[(Int, Array[Array[Double]])]) = { + val n = rank * (rank + 1) / 2 + val LYtY = factors.values.aggregate(new DoubleMatrix(n))( seqOp = (L, Y) => { + Y.foreach(y => dspr(1.0, new DoubleMatrix(y), L)) + L + }, combOp = (L1, L2) => { + L1.addi(L2) + }) + val YtY = new DoubleMatrix(rank, rank) + fillFullMatrix(LYtY, YtY) + YtY + } + + /** + * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's DSPR. 
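+   * Only the lower triangle is stored, packed row by row: (0,0), (1,0), (1,1), (2,0), (2,1),
+   * (2,2), ... For example, with alpha = 1 and x = [1, 2], L is incremented by [1, 2, 4],
+   * the packed form of [[1, 2], [2, 4]].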
+ * + * @param L the lower triangular part of the matrix packed in an array (row major) + */ + private def dspr(alpha: Double, x: DoubleMatrix, L: DoubleMatrix) = { + val n = x.length + var i = 0 + var j = 0 + var idx = 0 + var axi = 0.0 + val xd = x.data + val Ld = L.data + while (i < n) { + axi = alpha * xd(i) + j = 0 + while (j <= i) { + Ld(idx) += axi * xd(j) + j += 1 + idx += 1 + } + i += 1 } } /** * Flatten out blocked user or product factors into an RDD of (id, factor vector) pairs */ - def unblockFactors(blockedFactors: RDD[(Int, Array[Array[Double]])], + private def unblockFactors(blockedFactors: RDD[(Int, Array[Array[Double]])], outLinks: RDD[(Int, OutLinkBlock)]) = { blockedFactors.join(outLinks).flatMap{ case (b, (factors, outLinkBlock)) => for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i)) @@ -296,8 +362,11 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l val outLinkBlock = makeOutLinkBlock(numBlocks, ratings) Iterator.single((blockId, (inLinkBlock, outLinkBlock))) }, true) - links.persist(StorageLevel.MEMORY_AND_DISK) - (links.mapValues(_._1), links.mapValues(_._2)) + val inLinks = links.mapValues(_._1) + val outLinks = links.mapValues(_._2) + inLinks.persist(StorageLevel.MEMORY_AND_DISK) + outLinks.persist(StorageLevel.MEMORY_AND_DISK) + (inLinks, outLinks) } /** @@ -329,7 +398,7 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l rank: Int, lambda: Double, alpha: Double, - YtY: Broadcast[Option[DoubleMatrix]]) + YtY: Option[Broadcast[DoubleMatrix]]) : RDD[(Int, Array[Array[Double]])] = { val numBlocks = products.partitions.size @@ -352,8 +421,8 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l * Compute the new feature vectors for a block of the users matrix given the list of factors * it received from each product and its InLinkBlock. */ - def updateBlock(messages: Seq[(Int, Array[Array[Double]])], inLinkBlock: InLinkBlock, - rank: Int, lambda: Double, alpha: Double, YtY: Broadcast[Option[DoubleMatrix]]) + private def updateBlock(messages: Seq[(Int, Array[Array[Double]])], inLinkBlock: InLinkBlock, + rank: Int, lambda: Double, alpha: Double, YtY: Option[Broadcast[DoubleMatrix]]) : Array[Array[Double]] = { // Sort the incoming block factor messages by block ID and make them an array @@ -376,59 +445,41 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l for (productBlock <- 0 until numBlocks) { for (p <- 0 until blockFactors(productBlock).length) { val x = new DoubleMatrix(blockFactors(productBlock)(p)) - fillXtX(x, tempXtX) + tempXtX.fill(0.0) + dspr(1.0, x, tempXtX) val (us, rs) = inLinkBlock.ratingsForBlock(productBlock)(p) for (i <- 0 until us.length) { - implicitPrefs match { - case false => - userXtX(us(i)).addi(tempXtX) - SimpleBlas.axpy(rs(i), x, userXy(us(i))) - case true => - // Extension to the original paper to handle rs(i) < 0. confidence is a function - // of |rs(i)| instead so that it is never negative: - val confidence = 1 + alpha * abs(rs(i)) - userXtX(us(i)).addi(tempXtX.mul(confidence - 1)) - // For rs(i) < 0, the corresponding entry in P is 0 now, not 1 -- negative rs(i) - // means we try to reconstruct 0. We add terms only where P = 1, so, term below - // is now only added for rs(i) > 0: - if (rs(i) > 0) { - SimpleBlas.axpy(confidence, x, userXy(us(i))) - } + if (implicitPrefs) { + // Extension to the original paper to handle rs(i) < 0. 
confidence is a function + // of |rs(i)| instead so that it is never negative: + val confidence = 1 + alpha * abs(rs(i)) + SimpleBlas.axpy(confidence - 1.0, tempXtX, userXtX(us(i))) + // For rs(i) < 0, the corresponding entry in P is 0 now, not 1 -- negative rs(i) + // means we try to reconstruct 0. We add terms only where P = 1, so, term below + // is now only added for rs(i) > 0: + if (rs(i) > 0) { + SimpleBlas.axpy(confidence, x, userXy(us(i))) + } + } else { + userXtX(us(i)).addi(tempXtX) + SimpleBlas.axpy(rs(i), x, userXy(us(i))) } } } } // Solve the least-squares problem for each user and return the new feature vectors - userXtX.zipWithIndex.map{ case (triangularXtX, index) => + Array.range(0, numUsers).map { index => // Compute the full XtX matrix from the lower-triangular part we got above - fillFullMatrix(triangularXtX, fullXtX) + fillFullMatrix(userXtX(index), fullXtX) // Add regularization (0 until rank).foreach(i => fullXtX.data(i*rank + i) += lambda) // Solve the resulting matrix, which is symmetric and positive-definite - implicitPrefs match { - case false => Solve.solvePositive(fullXtX, userXy(index)).data - case true => Solve.solvePositive(fullXtX.add(YtY.value.get), userXy(index)).data - } - } - } - - /** - * Set xtxDest to the lower-triangular part of x transpose * x. For efficiency in summing - * these matrices, we store xtxDest as only rank * (rank+1) / 2 values, namely the values - * at (0,0), (1,0), (1,1), (2,0), (2,1), (2,2), etc in that order. - */ - private def fillXtX(x: DoubleMatrix, xtxDest: DoubleMatrix) { - var i = 0 - var pos = 0 - while (i < x.length) { - var j = 0 - while (j <= i) { - xtxDest.data(pos) = x.data(i) * x.data(j) - pos += 1 - j += 1 + if (implicitPrefs) { + Solve.solvePositive(fullXtX.addi(YtY.get.value), userXy(index)).data + } else { + Solve.solvePositive(fullXtX, userXy(index)).data } - i += 1 } } @@ -455,9 +506,10 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l /** - * Top-level methods for calling Alternating Least Squares (ALS) matrix factorizaton. + * Top-level methods for calling Alternating Least Squares (ALS) matrix factorization. */ object ALS { + /** * Train a matrix factorization model given an RDD of ratings given by users to some products, * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the @@ -470,15 +522,39 @@ object ALS { * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into + * @param seed random seed */ def train( ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double, - blocks: Int) - : MatrixFactorizationModel = - { + blocks: Int, + seed: Long + ): MatrixFactorizationModel = { + new ALS(blocks, rank, iterations, lambda, false, 1.0, seed).run(ratings) + } + + /** + * Train a matrix factorization model given an RDD of ratings given by users to some products, + * in the form of (userID, productID, rating) pairs. We approximate the ratings matrix as the + * product of two lower-rank matrices of a given rank (number of features). To solve for these + * features, we run a given number of iterations of ALS. This is done using a level of + * parallelism given by `blocks`. 
+ * + * @param ratings RDD of (userID, productID, rating) pairs + * @param rank number of features to use + * @param iterations number of iterations of ALS (recommended: 10-20) + * @param lambda regularization factor (recommended: 0.01) + * @param blocks level of parallelism to split computation into + */ + def train( + ratings: RDD[Rating], + rank: Int, + iterations: Int, + lambda: Double, + blocks: Int + ): MatrixFactorizationModel = { new ALS(blocks, rank, iterations, lambda, false, 1.0).run(ratings) } @@ -495,8 +571,7 @@ object ALS { * @param lambda regularization factor (recommended: 0.01) */ def train(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double) - : MatrixFactorizationModel = - { + : MatrixFactorizationModel = { train(ratings, rank, iterations, lambda, -1) } @@ -512,8 +587,7 @@ object ALS { * @param iterations number of iterations of ALS (recommended: 10-20) */ def train(ratings: RDD[Rating], rank: Int, iterations: Int) - : MatrixFactorizationModel = - { + : MatrixFactorizationModel = { train(ratings, rank, iterations, 0.01, -1) } @@ -530,6 +604,7 @@ object ALS { * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into * @param alpha confidence parameter (only applies when immplicitPrefs = true) + * @param seed random seed */ def trainImplicit( ratings: RDD[Rating], @@ -537,9 +612,34 @@ object ALS { iterations: Int, lambda: Double, blocks: Int, - alpha: Double) - : MatrixFactorizationModel = - { + alpha: Double, + seed: Long + ): MatrixFactorizationModel = { + new ALS(blocks, rank, iterations, lambda, true, alpha, seed).run(ratings) + } + + /** + * Train a matrix factorization model given an RDD of 'implicit preferences' given by users + * to some products, in the form of (userID, productID, preference) pairs. We approximate the + * ratings matrix as the product of two lower-rank matrices of a given rank (number of features). + * To solve for these features, we run a given number of iterations of ALS. This is done using + * a level of parallelism given by `blocks`. 
+ * + * @param ratings RDD of (userID, productID, rating) pairs + * @param rank number of features to use + * @param iterations number of iterations of ALS (recommended: 10-20) + * @param lambda regularization factor (recommended: 0.01) + * @param blocks level of parallelism to split computation into + * @param alpha confidence parameter (only applies when immplicitPrefs = true) + */ + def trainImplicit( + ratings: RDD[Rating], + rank: Int, + iterations: Int, + lambda: Double, + blocks: Int, + alpha: Double + ): MatrixFactorizationModel = { new ALS(blocks, rank, iterations, lambda, true, alpha).run(ratings) } @@ -555,8 +655,8 @@ object ALS { * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) */ - def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double, - alpha: Double): MatrixFactorizationModel = { + def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double, alpha: Double) + : MatrixFactorizationModel = { trainImplicit(ratings, rank, iterations, lambda, -1, alpha) } @@ -573,8 +673,7 @@ object ALS { * @param iterations number of iterations of ALS (recommended: 10-20) */ def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int) - : MatrixFactorizationModel = - { + : MatrixFactorizationModel = { trainImplicit(ratings, rank, iterations, 0.01, -1, 1.0) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index f98b0b536deaa..b9621530efa22 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -119,7 +119,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ def run(input: RDD[LabeledPoint]) : M = { val nfeatures: Int = input.first().features.length - val initialWeights = Array.fill(nfeatures)(1.0) + val initialWeights = new Array[Double](nfeatures) run(input, initialWeights) } @@ -134,15 +134,15 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] throw new SparkException("Input validation failed.") } - // Add a extra variable consisting of all 1.0's for the intercept. + // Prepend an extra variable consisting of all 1.0's for the intercept. 
val data = if (addIntercept) { - input.map(labeledPoint => (labeledPoint.label, Array(1.0, labeledPoint.features:_*))) + input.map(labeledPoint => (labeledPoint.label, labeledPoint.features.+:(1.0))) } else { input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) } val initialWeightsWithIntercept = if (addIntercept) { - Array(1.0, initialWeights:_*) + initialWeights.+:(1.0) } else { initialWeights } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/SVDSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/SVDSuite.scala index 32f3f141cd652..a92386865a189 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/SVDSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/SVDSuite.scala @@ -50,7 +50,7 @@ class SVDSuite extends FunSuite with BeforeAndAfterAll { val m = matrix.m val n = matrix.n val ret = DoubleMatrix.zeros(m, n) - matrix.data.toArray.map(x => ret.put(x.i, x.j, x.mval)) + matrix.data.collect().map(x => ret.put(x.i, x.j, x.mval)) ret } @@ -106,7 +106,7 @@ class SVDSuite extends FunSuite with BeforeAndAfterAll { val u = decomposed.U val s = decomposed.S val v = decomposed.V - val retrank = s.data.toArray.length + val retrank = s.data.collect().length assert(retrank == 1, "rank returned not one") @@ -139,7 +139,7 @@ class SVDSuite extends FunSuite with BeforeAndAfterAll { val u = decomposed.U val s = decomposed.S val v = decomposed.V - val retrank = s.data.toArray.length + val retrank = s.data.collect().length val densea = getDenseMatrix(a) val svd = Singular.sparseSVD(densea) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index a453de6767aa2..631d0e2ad9cdb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -104,4 +104,45 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMa val lossDiff = loss.init.zip(loss.tail).map { case (lhs, rhs) => lhs - rhs } assert(lossDiff.count(_ > 0).toDouble / lossDiff.size > 0.8) } + + test("Test the loss and gradient of first iteration with regularization.") { + + val gradient = new LogisticGradient() + val updater = new SquaredL2Updater() + + // Add a extra variable consisting of all 1.0's for the intercept. 
+ val testData = GradientDescentSuite.generateGDInput(2.0, -1.5, 10000, 42) + val data = testData.map { case LabeledPoint(label, features) => + label -> Array(1.0, features: _*) + } + + val dataRDD = sc.parallelize(data, 2).cache() + + // Prepare non-zero weights + val initialWeightsWithIntercept = Array(1.0, 0.5) + + val regParam0 = 0 + val (newWeights0, loss0) = GradientDescent.runMiniBatchSGD( + dataRDD, gradient, updater, 1, 1, regParam0, 1.0, initialWeightsWithIntercept) + + val regParam1 = 1 + val (newWeights1, loss1) = GradientDescent.runMiniBatchSGD( + dataRDD, gradient, updater, 1, 1, regParam1, 1.0, initialWeightsWithIntercept) + + def compareDouble(x: Double, y: Double, tol: Double = 1E-3): Boolean = { + math.abs(x - y) / (math.abs(y) + 1e-15) < tol + } + + assert(compareDouble( + loss1(0), + loss0(0) + (math.pow(initialWeightsWithIntercept(0), 2) + + math.pow(initialWeightsWithIntercept(1), 2)) / 2), + """For non-zero weights, the regVal should be \frac{1}{2}\sum_i w_i^2.""") + + assert( + compareDouble(newWeights1(0) , newWeights0(0) - initialWeightsWithIntercept(0)) && + compareDouble(newWeights1(1) , newWeights0(1) - initialWeightsWithIntercept(1)), + "The different between newWeights with/without regularization " + + "should be initialWeightsWithIntercept.") + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index 45e7d2db00c42..5aab9aba8f9c0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -23,9 +23,10 @@ import scala.util.Random import org.scalatest.FunSuite -import org.jblas._ +import org.jblas.DoubleMatrix import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.SparkContext._ object ALSSuite { @@ -115,6 +116,18 @@ class ALSSuite extends FunSuite with LocalSparkContext { testALS(100, 200, 2, 15, 0.7, 0.4, true, false, true) } + test("pseudorandomness") { + val ratings = sc.parallelize(ALSSuite.generateRatings(10, 20, 5, 0.5, false, false)._1, 2) + val model11 = ALS.train(ratings, 5, 1, 1.0, 2, 1) + val model12 = ALS.train(ratings, 5, 1, 1.0, 2, 1) + val u11 = model11.userFeatures.values.flatMap(_.toList).collect().toList + val u12 = model12.userFeatures.values.flatMap(_.toList).collect().toList + val model2 = ALS.train(ratings, 5, 1, 1.0, 2, 2) + val u2 = model2.userFeatures.values.flatMap(_.toList).collect().toList + assert(u11 == u12) + assert(u11 != u2) + } + /** * Test if we can correctly factorize R = U * P where U and P are of known rank. 
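
[Editor's note] The GradientDescentSuite test above pins down two consequences of squared-L2 regularization after a single step with stepSize 1.0: the reported loss grows by 0.5 * ||w||^2, and the returned weights shift by exactly -w relative to the unregularized run. A hedged NumPy sketch of that arithmetic, assuming an updater of the form w' = w * (1 - step * regParam) - step * g that reports a regularization value of 0.5 * regParam * ||w||^2 (which is what the two assertions rely on); the gradient below is illustrative.

import numpy as np

w0 = np.array([1.0, 0.5])          # initialWeightsWithIntercept in the test
g = np.array([0.3, -0.2])          # some minibatch gradient (illustrative)

def l2_step(w, g, step, reg):
    return w * (1.0 - step * reg) - step * g, 0.5 * reg * w.dot(w)

w_noreg, regval0 = l2_step(w0, g, step=1.0, reg=0.0)
w_reg, regval1 = l2_step(w0, g, step=1.0, reg=1.0)

assert np.allclose(w_reg - w_noreg, -w0)                  # weights differ by -w0
assert np.isclose(regval1 - regval0, 0.5 * w0.dot(w0))    # losses differ by 0.5*||w0||^2
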
* diff --git a/pom.xml b/pom.xml index 3a530685b8e5a..524e5daff5388 100644 --- a/pom.xml +++ b/pom.xml @@ -25,10 +25,10 @@ org.apache.spark spark-parent - 1.0.0-incubating-SNAPSHOT + 1.0.0-SNAPSHOT pom Spark Project Parent POM - http://spark.incubator.apache.org/ + http://spark.apache.org/ Apache 2.0 License @@ -37,9 +37,9 @@ - scm:git:git@github.com:apache/incubator-spark.git - scm:git:https://git-wip-us.apache.org/repos/asf/incubator-spark.git - scm:git:git@github.com:apache/incubator-spark.git + scm:git:git@github.com:apache/spark.git + scm:git:https://git-wip-us.apache.org/repos/asf/spark.git + scm:git:git@github.com:apache/spark.git HEAD @@ -49,7 +49,7 @@ matei.zaharia@gmail.com http://www.cs.berkeley.edu/~matei Apache Software Foundation - http://spark.incubator.apache.org + http://spark.apache.org @@ -64,23 +64,23 @@ Dev Mailing List - dev@spark.incubator.apache.org - dev-subscribe@spark.incubator.apache.org - dev-unsubscribe@spark.incubator.apache.org + dev@spark.apache.org + dev-subscribe@spark.apache.org + dev-unsubscribe@spark.apache.org User Mailing List - user@spark.incubator.apache.org - user-subscribe@spark.incubator.apache.org - user-unsubscribe@spark.incubator.apache.org + user@spark.apache.org + user-subscribe@spark.apache.org + user-unsubscribe@spark.apache.org Commits Mailing List - commits@spark.incubator.apache.org - commits-subscribe@spark.incubator.apache.org - commits-unsubscribe@spark.incubator.apache.org + commits@spark.apache.org + commits-subscribe@spark.apache.org + commits-unsubscribe@spark.apache.org @@ -127,7 +127,18 @@ maven-repo Maven Repository - http://repo.maven.apache.org/maven2 + https://repo.maven.apache.org/maven2 + + true + + + false + + + + apache-repo + Apache Repository + https://repository.apache.org/content/repositories/releases true @@ -138,7 +149,18 @@ jboss-repo JBoss Repository - http://repository.jboss.org/nexus/content/repositories/releases + https://repository.jboss.org/nexus/content/repositories/releases + + true + + + false + + + + mqtt-repo + MQTT Repository + https://repo.eclipse.org/content/repositories/paho-releases true @@ -150,11 +172,32 @@ cloudera-repo Cloudera Repository https://repository.cloudera.com/artifactory/cloudera-repos + + true + + + false + + + org.eclipse.jetty + jetty-util + 7.6.8.v20121106 + + + org.eclipse.jetty + jetty-security + 7.6.8.v20121106 + + + org.eclipse.jetty + jetty-plus + 7.6.8.v20121106 + org.eclipse.jetty jetty-server @@ -206,11 +249,6 @@ snappy-java 1.0.5 - - org.ow2.asm - asm - 4.0 - com.clearspring.analytics stream @@ -230,11 +268,31 @@ com.twitter chill_${scala.binary.version} 0.3.1 + + + org.ow2.asm + asm + + + org.ow2.asm + asm-commons + + com.twitter chill-java 0.3.1 + + + org.ow2.asm + asm + + + org.ow2.asm + asm-commons + + ${akka.group} @@ -295,10 +353,15 @@ mesos ${mesos.version} + + commons-net + commons-net + 2.2 + io.netty netty-all - 4.0.13.Final + 4.0.17.Final org.apache.derby @@ -310,6 +373,9 @@ net.liftweb lift-json_${scala.binary.version} 2.5.1 + org.scala-lang @@ -374,7 +440,7 @@ 3.1 test - + org.mockito mockito-all 1.8.5 @@ -393,9 +459,9 @@ test - org.apache.zookeeper - zookeeper - 3.4.5 + org.apache.curator + curator-recipes + 2.4.0 org.jboss.netty @@ -412,21 +478,32 @@ asm asm + + org.ow2.asm + asm + org.jboss.netty netty - org.codehaus.jackson - * + commons-logging + commons-logging + + + + org.apache.avro + avro + 1.7.4 + - org.sonatype.sisu.inject - * + org.jboss.netty + netty - commons-logging - commons-logging + io.netty + netty @@ -445,16 +522,12 @@ asm - 
org.jboss.netty - netty - - - org.codehaus.jackson - * + org.ow2.asm + asm - org.sonatype.sisu.inject - * + org.jboss.netty + netty @@ -468,16 +541,12 @@ asm - org.jboss.netty - netty - - - org.codehaus.jackson - * + org.ow2.asm + asm - org.sonatype.sisu.inject - * + org.jboss.netty + netty @@ -492,38 +561,13 @@ asm - org.jboss.netty - netty - - - org.codehaus.jackson - * - - - org.sonatype.sisu.inject - * + org.ow2.asm + asm - - - - - org.apache.avro - avro - 1.7.4 - - - org.apache.avro - avro-ipc - 1.7.4 - org.jboss.netty netty - - io.netty - netty - @@ -613,12 +657,13 @@ org.apache.maven.plugins maven-compiler-plugin - 2.5.1 + 3.1 ${java.version} ${java.version} UTF-8 1024m + true @@ -633,7 +678,7 @@ org.scalatest scalatest-maven-plugin - 1.0-M2 + 1.0-RC2 ${project.build.directory}/surefire-reports . @@ -739,10 +784,42 @@ 0.23.7 - yarn + + + + + spark-ganglia-lgpl + + extras/spark-ganglia-lgpl + + + + + java8-tests + + + + + org.apache.maven.plugins + maven-jar-plugin + 2.4 + + + + test-jar + + + + + + + + + extras/java8-tests + @@ -758,5 +835,51 @@ + + + + hadoop-provided + + false + + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.hadoop + hadoop-yarn-api + provided + + + org.apache.hadoop + hadoop-yarn-common + provided + + + org.apache.hadoop + hadoop-yarn-client + provided + + + org.apache.avro + avro + provided + + + org.apache.avro + avro-ipc + provided + + + org.apache.zookeeper + zookeeper + provided + + + + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 39f5503452709..aebc0017c847d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -23,6 +23,8 @@ import AssemblyKeys._ import scala.util.Properties import org.scalastyle.sbt.ScalastylePlugin.{Settings => ScalaStyleSettings} +import scala.collection.JavaConversions._ + // For Sonatype publishing //import com.jsuereth.pgp.sbtplugin.PgpKeys._ @@ -63,7 +65,7 @@ object SparkBuild extends Build { lazy val mllib = Project("mllib", file("mllib"), settings = mllibSettings) dependsOn(core) lazy val assemblyProj = Project("assembly", file("assembly"), settings = assemblyProjSettings) - .dependsOn(core, graphx, bagel, mllib, repl, streaming) dependsOn(maybeYarn: _*) + .dependsOn(core, graphx, bagel, mllib, repl, streaming) dependsOn(maybeYarn: _*) dependsOn(maybeGanglia: _*) lazy val assembleDeps = TaskKey[Unit]("assemble-deps", "Build assembly of dependencies and packages Spark projects") @@ -87,13 +89,28 @@ object SparkBuild extends Build { case Some(v) => v.toBoolean } lazy val hadoopClient = if (hadoopVersion.startsWith("0.20.") || hadoopVersion == "1.0.0") "hadoop-core" else "hadoop-client" - - // Conditionally include the yarn sub-project + val maybeAvro = if (hadoopVersion.startsWith("0.23.") && isYarnEnabled) Seq("org.apache.avro" % "avro" % "1.7.4") else Seq() + + // Include Ganglia integration if the user has enabled Ganglia + // This is isolated from the normal build due to LGPL-licensed code in the library + lazy val isGangliaEnabled = Properties.envOrNone("SPARK_GANGLIA_LGPL").isDefined + lazy val gangliaProj = Project("spark-ganglia-lgpl", file("extras/spark-ganglia-lgpl"), settings = gangliaSettings).dependsOn(core) + val maybeGanglia: Seq[ClasspathDependency] = if (isGangliaEnabled) Seq(gangliaProj) else Seq() + val maybeGangliaRef: Seq[ProjectReference] = if (isGangliaEnabled) Seq(gangliaProj) else Seq() + + // Include the Java 8 project if the JVM version is 8+ + lazy val javaVersion = System.getProperty("java.specification.version") + lazy val isJava8Enabled = 
javaVersion.toDouble >= "1.8".toDouble + val maybeJava8Tests = if (isJava8Enabled) Seq[ProjectReference](java8Tests) else Seq[ProjectReference]() + lazy val java8Tests = Project("java8-tests", file("extras/java8-tests"), settings = java8TestsSettings). + dependsOn(core) dependsOn(streaming % "compile->compile;test->test") + + // Include the YARN project if the user has enabled YARN lazy val yarnAlpha = Project("yarn-alpha", file("yarn/alpha"), settings = yarnAlphaSettings) dependsOn(core) lazy val yarn = Project("yarn", file("yarn/stable"), settings = yarnSettings) dependsOn(core) - lazy val maybeYarn = if (isYarnEnabled) Seq[ClasspathDependency](if (isNewHadoop) yarn else yarnAlpha) else Seq[ClasspathDependency]() - lazy val maybeYarnRef = if (isYarnEnabled) Seq[ProjectReference](if (isNewHadoop) yarn else yarnAlpha) else Seq[ProjectReference]() + lazy val maybeYarn: Seq[ClasspathDependency] = if (isYarnEnabled) Seq(if (isNewHadoop) yarn else yarnAlpha) else Seq() + lazy val maybeYarnRef: Seq[ProjectReference] = if (isYarnEnabled) Seq(if (isNewHadoop) yarn else yarnAlpha) else Seq() lazy val externalTwitter = Project("external-twitter", file("external/twitter"), settings = twitterSettings) .dependsOn(streaming % "compile->compile;test->test") @@ -114,22 +131,26 @@ object SparkBuild extends Build { lazy val allExternalRefs = Seq[ProjectReference](externalTwitter, externalKafka, externalFlume, externalZeromq, externalMqtt) lazy val examples = Project("examples", file("examples"), settings = examplesSettings) - .dependsOn(core, mllib, bagel, streaming, externalTwitter) dependsOn(allExternal: _*) + .dependsOn(core, mllib, graphx, bagel, streaming, externalTwitter) dependsOn(allExternal: _*) - // Everything except assembly, tools and examples belong to packageProjects - lazy val packageProjects = Seq[ProjectReference](core, repl, bagel, streaming, mllib) ++ maybeYarnRef + // Everything except assembly, tools, java8Tests and examples belong to packageProjects + lazy val packageProjects = Seq[ProjectReference](core, repl, bagel, streaming, mllib, graphx) ++ maybeYarnRef ++ maybeGangliaRef - lazy val allProjects = packageProjects ++ allExternalRefs ++ Seq[ProjectReference](examples, tools, assemblyProj) + lazy val allProjects = packageProjects ++ allExternalRefs ++ + Seq[ProjectReference](examples, tools, assemblyProj) ++ maybeJava8Tests def sharedSettings = Defaults.defaultSettings ++ Seq( organization := "org.apache.spark", - version := "1.0.0-incubating-SNAPSHOT", + version := "1.0.0-SNAPSHOT", scalaVersion := "2.10.3", scalacOptions := Seq("-Xmax-classfile-name", "120", "-unchecked", "-deprecation", "-target:" + SCALAC_JVM_VERSION), javacOptions := Seq("-target", JAVAC_JVM_VERSION, "-source", JAVAC_JVM_VERSION), unmanagedJars in Compile <<= baseDirectory map { base => (base / "lib" ** "*.jar").classpath }, retrieveManaged := true, + javaHome := Properties.envOrNone("JAVA_HOME").map(file), + // This is to add convenience of enabling sbt -Dsbt.offline=true for making the build offline. 
+ offline := "true".equalsIgnoreCase(sys.props("sbt.offline")), retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", transitiveClassifiers in Scope.GlobalScope := Seq("sources"), testListeners <<= target.map(t => Seq(new eu.henkelmann.sbt.JUnitXmlTestsListener(t.getAbsolutePath))), @@ -138,19 +159,33 @@ object SparkBuild extends Build { fork := true, javaOptions in Test += "-Dspark.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", + javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark").map { case (k,v) => s"-D$k=$v" }.toSeq, javaOptions += "-Xmx3g", // Show full stack trace and duration in test cases. testOptions in Test += Tests.Argument("-oDF"), + // Remove certain packages from Scaladoc + scalacOptions in (Compile,doc) := Seq("-skip-packages", Seq( + "akka", + "org.apache.spark.network", + "org.apache.spark.deploy", + "org.apache.spark.util.collection" + ).mkString(":")), // Only allow one test at a time, even across projects, since they run in the same JVM concurrentRestrictions in Global += Tags.limit(Tags.Test, 1), - // also check the local Maven repository ~/.m2 - resolvers ++= Seq(Resolver.file("Local Maven Repo", file(Path.userHome + "/.m2/repository"))), - - // For Sonatype publishing - resolvers ++= Seq("sonatype-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots", - "sonatype-staging" at "https://oss.sonatype.org/service/local/staging/deploy/maven2/"), + resolvers ++= Seq( + "Maven Repository" at "https://repo.maven.apache.org/maven2", + "Apache Repository" at "https://repository.apache.org/content/repositories/releases", + "JBoss Repository" at "https://repository.jboss.org/nexus/content/repositories/releases/", + "MQTT Repository" at "https://repo.eclipse.org/content/repositories/paho-releases/", + "Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/", + // For Sonatype publishing + //"sonatype-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots", + //"sonatype-staging" at "https://oss.sonatype.org/service/local/staging/deploy/maven2/", + // also check the local Maven repository ~/.m2 + Resolver.mavenLocal + ), publishMavenStyle := true, @@ -162,7 +197,7 @@ object SparkBuild extends Build { apache 13 - http://spark.incubator.apache.org/ + http://spark.apache.org/ Apache 2.0 License @@ -171,8 +206,8 @@ object SparkBuild extends Build { - scm:git:git@github.com:apache/incubator-spark.git - scm:git:git@github.com:apache/incubator-spark.git + scm:git:git@github.com:apache/spark.git + scm:git:git@github.com:apache/spark.git @@ -181,7 +216,7 @@ object SparkBuild extends Build { matei.zaharia@gmail.com http://www.cs.berkeley.edu/~matei Apache Software Foundation - http://spark.incubator.apache.org + http://spark.apache.org @@ -202,8 +237,11 @@ object SparkBuild extends Build { */ libraryDependencies ++= Seq( - "io.netty" % "netty-all" % "4.0.13.Final", + "io.netty" % "netty-all" % "4.0.17.Final", "org.eclipse.jetty" % "jetty-server" % "7.6.8.v20121106", + "org.eclipse.jetty" % "jetty-util" % "7.6.8.v20121106", + "org.eclipse.jetty" % "jetty-plus" % "7.6.8.v20121106", + "org.eclipse.jetty" % "jetty-security" % "7.6.8.v20121106", /** Workaround for SPARK-959. Dependency used by org.eclipse.jetty. Fixed in ivy 2.3.0. 
*/ "org.eclipse.jetty.orbit" % "javax.servlet" % "2.5.0.v201103041518" artifacts Artifact("javax.servlet", "jar", "jar"), "org.scalatest" %% "scalatest" % "1.9.1" % "test", @@ -231,20 +269,15 @@ object SparkBuild extends Build { val slf4jVersion = "1.7.5" - val excludeCglib = ExclusionRule(organization = "org.sonatype.sisu.inject") - val excludeJackson = ExclusionRule(organization = "org.codehaus.jackson") val excludeNetty = ExclusionRule(organization = "org.jboss.netty") - val excludeAsm = ExclusionRule(organization = "asm") - val excludeSnappy = ExclusionRule(organization = "org.xerial.snappy") + val excludeAsm = ExclusionRule(organization = "org.ow2.asm") + val excludeOldAsm = ExclusionRule(organization = "asm") val excludeCommonsLogging = ExclusionRule(organization = "commons-logging") val excludeSLF4J = ExclusionRule(organization = "org.slf4j") + val excludeScalap = ExclusionRule(organization = "org.scala-lang", artifact = "scalap") def coreSettings = sharedSettings ++ Seq( name := "spark-core", - resolvers ++= Seq( - "JBoss Repository" at "http://repository.jboss.org/nexus/content/repositories/releases/", - "Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/" - ), libraryDependencies ++= Seq( "com.google.guava" % "guava" % "14.0.1", @@ -278,8 +311,7 @@ object SparkBuild extends Build { "com.codahale.metrics" % "metrics-graphite" % "3.0.0", "com.twitter" %% "chill" % "0.3.1", "com.twitter" % "chill-java" % "0.3.1", - "com.clearspring.analytics" % "stream" % "2.5.1", - "org.msgpack" %% "msgpack-scala" % "0.6.8" + "com.clearspring.analytics" % "stream" % "2.5.1" ) ) @@ -298,7 +330,7 @@ object SparkBuild extends Build { name := "spark-examples", libraryDependencies ++= Seq( "com.twitter" %% "algebird-core" % "0.1.11", - "org.apache.hbase" % "hbase" % HBASE_VERSION excludeAll(excludeNetty, excludeAsm, excludeCommonsLogging), + "org.apache.hbase" % "hbase" % HBASE_VERSION excludeAll(excludeNetty, excludeAsm, excludeOldAsm, excludeCommonsLogging), "org.apache.cassandra" % "cassandra-all" % "1.2.6" exclude("com.google.guava", "guava") exclude("com.googlecode.concurrentlinkedhashmap", "concurrentlinkedhashmap-lru") @@ -306,7 +338,7 @@ object SparkBuild extends Build { exclude("io.netty", "netty") exclude("jline","jline") exclude("org.apache.cassandra.deps", "avro") - excludeAll(excludeSnappy, excludeCglib, excludeSLF4J) + excludeAll(excludeSLF4J) ) ) ++ assemblySettings ++ extraAssemblySettings @@ -362,6 +394,17 @@ object SparkBuild extends Build { name := "spark-yarn" ) + def gangliaSettings = sharedSettings ++ Seq( + name := "spark-ganglia-lgpl", + libraryDependencies += "com.codahale.metrics" % "metrics-ganglia" % "3.0.0" + ) + + def java8TestsSettings = sharedSettings ++ Seq( + name := "java8-tests", + javacOptions := Seq("-target", "1.8", "-source", "1.8"), + testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a") + ) + // Conditionally include the YARN dependencies because some tools look at all sub-projects and will complain // if we refer to nonexistent dependencies (e.g. hadoop-yarn-api from a Hadoop version without YARN). def extraYarnSettings = if(isYarnEnabled) yarnEnabledSettings else Seq() @@ -369,10 +412,10 @@ object SparkBuild extends Build { def yarnEnabledSettings = Seq( libraryDependencies ++= Seq( // Exclude rule required for all ? 
- "org.apache.hadoop" % hadoopClient % hadoopVersion excludeAll(excludeJackson, excludeNetty, excludeAsm, excludeCglib), - "org.apache.hadoop" % "hadoop-yarn-api" % hadoopVersion excludeAll(excludeJackson, excludeNetty, excludeAsm, excludeCglib), - "org.apache.hadoop" % "hadoop-yarn-common" % hadoopVersion excludeAll(excludeJackson, excludeNetty, excludeAsm, excludeCglib), - "org.apache.hadoop" % "hadoop-yarn-client" % hadoopVersion excludeAll(excludeJackson, excludeNetty, excludeAsm, excludeCglib) + "org.apache.hadoop" % hadoopClient % hadoopVersion excludeAll(excludeNetty, excludeAsm, excludeOldAsm), + "org.apache.hadoop" % "hadoop-yarn-api" % hadoopVersion excludeAll(excludeNetty, excludeAsm, excludeOldAsm), + "org.apache.hadoop" % "hadoop-yarn-common" % hadoopVersion excludeAll(excludeNetty, excludeAsm, excludeOldAsm), + "org.apache.hadoop" % "hadoop-yarn-client" % hadoopVersion excludeAll(excludeNetty, excludeAsm, excludeOldAsm) ) ) @@ -418,7 +461,7 @@ object SparkBuild extends Build { def flumeSettings() = sharedSettings ++ Seq( name := "spark-streaming-flume", libraryDependencies ++= Seq( - "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" excludeAll(excludeNetty, excludeSnappy) + "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" excludeAll(excludeNetty) ) ) @@ -431,7 +474,6 @@ object SparkBuild extends Build { def mqttSettings() = streamingSettings ++ Seq( name := "spark-streaming-mqtt", - resolvers ++= Seq("Eclipse Repo" at "https://repo.eclipse.org/content/repositories/paho-releases/"), libraryDependencies ++= Seq("org.eclipse.paho" % "mqtt-client" % "0.4.0") ) } diff --git a/project/plugins.sbt b/project/plugins.sbt index 914f2e05a402a..32bc044a93221 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -19,3 +19,4 @@ addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.4") addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.4.0") +addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.0") diff --git a/python/pyspark/context.py b/python/pyspark/context.py index d49d2677278a6..89dd614e7f023 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -20,6 +20,7 @@ import sys from threading import Lock from tempfile import NamedTemporaryFile +from collections import namedtuple from pyspark import accumulators from pyspark.accumulators import Accumulator @@ -29,6 +30,7 @@ from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, MsgPackDeserializer from pyspark.storagelevel import StorageLevel +from pyspark import rdd from pyspark.rdd import RDD from py4j.java_collections import ListConverter @@ -83,6 +85,11 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, ... ValueError:... 
""" + if rdd._extract_concise_traceback() is not None: + self._callsite = rdd._extract_concise_traceback() + else: + tempNamedTuple = namedtuple("Callsite", "function file linenum") + self._callsite = tempNamedTuple(function=None, file=None, linenum=None) SparkContext._ensure_initialized(self, gateway=gateway) self.environment = environment or {} @@ -169,7 +176,14 @@ def _ensure_initialized(cls, instance=None, gateway=None): if instance: if SparkContext._active_spark_context and SparkContext._active_spark_context != instance: - raise ValueError("Cannot run multiple SparkContexts at once") + currentMaster = SparkContext._active_spark_context.master + currentAppName = SparkContext._active_spark_context.appName + callsite = SparkContext._active_spark_context._callsite + + # Raise error if there is already a running Spark context + raise ValueError("Cannot run multiple SparkContexts at once; existing SparkContext(app=%s, master=%s)" \ + " created by %s at %s:%s " \ + % (currentAppName, currentMaster, callsite.function, callsite.file, callsite.linenum)) else: SparkContext._active_spark_context = instance @@ -472,6 +486,37 @@ def _getJavaStorageLevel(self, storageLevel): return newStorageLevel(storageLevel.useDisk, storageLevel.useMemory, storageLevel.deserialized, storageLevel.replication) + def setJobGroup(self, groupId, description): + """ + Assigns a group ID to all the jobs started by this thread until the group ID is set to a + different value or cleared. + + Often, a unit of execution in an application consists of multiple Spark actions or jobs. + Application programmers can use this method to group all those jobs together and give a + group description. Once set, the Spark web UI will associate such jobs with this group. + """ + self._jsc.setJobGroup(groupId, description) + + def setLocalProperty(self, key, value): + """ + Set a local property that affects jobs submitted from this thread, such as the + Spark fair scheduler pool. + """ + self._jsc.setLocalProperty(key, value) + + def getLocalProperty(self, key): + """ + Get a local property set in this thread, or null if it is missing. See + L{setLocalProperty} + """ + return self._jsc.getLocalProperty(key) + + def sparkUser(self): + """ + Get SPARK_USER for user who is running SparkContext. 
+ """ + return self._jsc.sc().sparkUser() + def _test(): import atexit import doctest diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index c15add5237507..6a16756e0576d 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -29,7 +29,7 @@ def launch_gateway(): # Launch the Py4j gateway using Spark's run command so that we pick up the - # proper classpath and SPARK_MEM settings from spark-env.sh + # proper classpath and settings from spark-env.sh on_windows = platform.system() == "Windows" script = "./bin/spark-class.cmd" if on_windows else "./bin/spark-class" command = [os.path.join(SPARK_HOME, script), "py4j.GatewayServer", diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 1330e6146800c..ae09dbff02a36 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -18,6 +18,7 @@ from base64 import standard_b64encode as b64enc import copy from collections import defaultdict +from collections import namedtuple from itertools import chain, ifilter, imap import operator import os @@ -28,13 +29,15 @@ from tempfile import NamedTemporaryFile from threading import Thread import warnings +from heapq import heappush, heappop, heappushpop from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ - BatchedSerializer, CloudPickleSerializer, pack_long + BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup from pyspark.statcounter import StatCounter from pyspark.rddsampler import RDDSampler +from pyspark.storagelevel import StorageLevel from py4j.java_collections import ListConverter, MapConverter @@ -42,12 +45,14 @@ __all__ = ["RDD"] def _extract_concise_traceback(): + """ + This function returns the traceback info for a callsite, returns a dict + with function name, file name and line number + """ tb = traceback.extract_stack() + callsite = namedtuple("Callsite", "function file linenum") if len(tb) == 0: - return "I'm lost!" - # HACK: This function is in a file called 'rdd.py' in the top level of - # everything PySpark. Just trim off the directory name and assume - # everything in that tree is PySpark guts. + return None file, line, module, what = tb[len(tb) - 1] sparkpath = os.path.dirname(file) first_spark_frame = len(tb) - 1 @@ -58,16 +63,20 @@ def _extract_concise_traceback(): break if first_spark_frame == 0: file, line, fun, what = tb[0] - return "%s at %s:%d" % (fun, file, line) + return callsite(function=fun, file=file, linenum=line) sfile, sline, sfun, swhat = tb[first_spark_frame] ufile, uline, ufun, uwhat = tb[first_spark_frame-1] - return "%s at %s:%d" % (sfun, ufile, uline) + return callsite(function=sfun, file=ufile, linenum=uline) _spark_stack_depth = 0 class _JavaStackTrace(object): def __init__(self, sc): - self._traceback = _extract_concise_traceback() + tb = _extract_concise_traceback() + if tb is not None: + self._traceback = "%s at %s:%s" % (tb.function, tb.file, tb.linenum) + else: + self._traceback = "Error! Could not extract traceback info" self._context = sc def __enter__(self): @@ -95,6 +104,13 @@ def __init__(self, jrdd, ctx, jrdd_deserializer): self.is_checkpointed = False self.ctx = ctx self._jrdd_deserializer = jrdd_deserializer + self._id = jrdd.id() + + def id(self): + """ + A unique ID for this RDD (within its SparkContext). 
+ """ + return self._id def __repr__(self): return self._jrdd.toString() @@ -163,7 +179,7 @@ def getCheckpointFile(self): def map(self, f, preservesPartitioning=False): """ - Return a new RDD containing the distinct elements in this RDD. + Return a new RDD by applying a function to each element of this RDD. """ def func(split, iterator): return imap(f, iterator) return PipelinedRDD(self, func, preservesPartitioning) @@ -252,6 +268,7 @@ def sample(self, withReplacement, fraction, seed): >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98] """ + assert fraction >= 0.0, "Invalid fraction value: %s" % fraction return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True) # this is ported from scala/spark/RDD.scala @@ -272,6 +289,9 @@ def takeSample(self, withReplacement, num, seed): if (num < 0): raise ValueError + if (initialCount == 0): + return list() + if initialCount > sys.maxint - 1: maxSelected = sys.maxint - 1 else: @@ -319,6 +339,23 @@ def union(self, other): return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx, self.ctx.serializer) + def intersection(self, other): + """ + Return the intersection of this RDD and another one. The output will not + contain any duplicate elements, even if the input RDDs did. + + Note that this method performs a shuffle internally. + + >>> rdd1 = sc.parallelize([1, 10, 2, 3, 4, 5]) + >>> rdd2 = sc.parallelize([1, 6, 2, 3, 7, 8]) + >>> rdd1.intersection(rdd2).collect() + [1, 2, 3] + """ + return self.map(lambda v: (v, None)) \ + .cogroup(other.map(lambda v: (v, None))) \ + .filter(lambda x: (len(x[1][0]) != 0) and (len(x[1][1]) != 0)) \ + .keys() + def _reserialize(self): if self._jrdd_deserializer == self.ctx.serializer: return self @@ -534,7 +571,26 @@ def func(iterator): return reduce(op, vals, zeroValue) # TODO: aggregate + + + def max(self): + """ + Find the maximum item in this RDD. + + >>> sc.parallelize([1.0, 5.0, 43.0, 10.0]).max() + 43.0 + """ + return self.reduce(max) + def min(self): + """ + Find the maximum item in this RDD. + + >>> sc.parallelize([1.0, 5.0, 43.0, 10.0]).min() + 1.0 + """ + return self.reduce(min) + def sum(self): """ Add up the elements in this RDD. @@ -628,6 +684,30 @@ def mergeMaps(m1, m2): m1[k] += v return m1 return self.mapPartitions(countPartition).reduce(mergeMaps) + + def top(self, num): + """ + Get the top N elements from a RDD. + + Note: It returns the list sorted in ascending order. + >>> sc.parallelize([10, 4, 2, 12, 3]).top(1) + [12] + >>> sc.parallelize([2, 3, 4, 5, 6]).cache().top(2) + [5, 6] + """ + def topIterator(iterator): + q = [] + for k in iterator: + if len(q) < num: + heappush(q, k) + else: + heappushpop(q, k) + yield q + + def merge(a, b): + return next(topIterator(a + b)) + + return sorted(self.mapPartitions(topIterator).reduce(merge)) def take(self, num): """ @@ -915,7 +995,21 @@ def _mergeCombiners(iterator): combiners[k] = mergeCombiners(combiners[k], v) return combiners.iteritems() return shuffled.mapPartitions(_mergeCombiners) + + def foldByKey(self, zeroValue, func, numPartitions=None): + """ + Merge the values for each key using an associative function "func" and a neutral "zeroValue" + which may be added to the result an arbitrary number of times, and must not change + the result (e.g., 0 for addition, or 1 for multiplication.). 
+ >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) + >>> from operator import add + >>> rdd.foldByKey(0, add).collect() + [('a', 2), ('b', 1)] + """ + return self.combineByKey(lambda v: func(zeroValue, v), func, func, numPartitions) + + # TODO: support variant with custom partitioner def groupByKey(self, numPartitions=None): """ @@ -1057,6 +1151,65 @@ def coalesce(self, numPartitions, shuffle=False): jrdd = self._jrdd.coalesce(numPartitions) return RDD(jrdd, self.ctx, self._jrdd_deserializer) + def zip(self, other): + """ + Zips this RDD with another one, returning key-value pairs with the first element in each RDD + second element in each RDD, etc. Assumes that the two RDDs have the same number of + partitions and the same number of elements in each partition (e.g. one was made through + a map on the other). + + >>> x = sc.parallelize(range(0,5)) + >>> y = sc.parallelize(range(1000, 1005)) + >>> x.zip(y).collect() + [(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)] + """ + pairRDD = self._jrdd.zip(other._jrdd) + deserializer = PairDeserializer(self._jrdd_deserializer, + other._jrdd_deserializer) + return RDD(pairRDD, self.ctx, deserializer) + + def name(self): + """ + Return the name of this RDD. + """ + name_ = self._jrdd.name() + if not name_: + return None + return name_.encode('utf-8') + + def setName(self, name): + """ + Assign a name to this RDD. + >>> rdd1 = sc.parallelize([1,2]) + >>> rdd1.setName('RDD1') + >>> rdd1.name() + 'RDD1' + """ + self._jrdd.setName(name) + + def toDebugString(self): + """ + A description of this RDD and its recursive dependencies for debugging. + """ + debug_string = self._jrdd.toDebugString() + if not debug_string: + return None + return debug_string.encode('utf-8') + + def getStorageLevel(self): + """ + Get the RDD's current storage level. + >>> rdd1 = sc.parallelize([1,2]) + >>> rdd1.getStorageLevel() + StorageLevel(False, False, False, 1) + """ + java_storage_level = self._jrdd.getStorageLevel() + storage_level = StorageLevel(java_storage_level.useDisk(), + java_storage_level.useMemory(), + java_storage_level.deserialized(), + java_storage_level.replication()) + return storage_level + # TODO: `lookup` is disabled because we can't make direct comparisons based # on the key; we need to compare the hash of the key to the hash of the # keys in the pairs. This could be an expensive operation, since those diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 6fdc552e6a870..967ad4b564479 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -205,7 +205,7 @@ def __init__(self, key_ser, val_ser): self.key_ser = key_ser self.val_ser = val_ser - def load_stream(self, stream): + def prepare_keys_values(self, stream): key_stream = self.key_ser._load_stream_without_unbatching(stream) val_stream = self.val_ser._load_stream_without_unbatching(stream) key_is_batched = isinstance(self.key_ser, BatchedSerializer) @@ -213,6 +213,10 @@ def load_stream(self, stream): for (keys, vals) in izip(key_stream, val_stream): keys = keys if key_is_batched else [keys] vals = vals if val_is_batched else [vals] + yield (keys, vals) + + def load_stream(self, stream): + for (keys, vals) in self.prepare_keys_values(stream): for pair in product(keys, vals): yield pair @@ -225,6 +229,29 @@ def __str__(self): (str(self.key_ser), str(self.val_ser)) +class PairDeserializer(CartesianDeserializer): + """ + Deserializes the JavaRDD zip() of two PythonRDDs. 
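
[Editor's note] PairDeserializer below reuses CartesianDeserializer's batching logic but pairs the two streams positionally (izip) instead of forming the cross product, which is what the new zip() needs, given that zip() assumes equal partition counts and equal elements per partition. A minimal pure-Python illustration of the difference (Python 2 itertools, matching the codebase here).

from itertools import product, izip

keys = [0, 1, 2]
vals = ['a', 'b', 'c']

print(list(product(keys, vals)))   # CartesianDeserializer-style: 9 pairs
print(list(izip(keys, vals)))      # PairDeserializer-style: [(0,'a'), (1,'b'), (2,'c')]
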
+ """ + + def __init__(self, key_ser, val_ser): + self.key_ser = key_ser + self.val_ser = val_ser + + def load_stream(self, stream): + for (keys, vals) in self.prepare_keys_values(stream): + for pair in izip(keys, vals): + yield pair + + def __eq__(self, other): + return isinstance(other, PairDeserializer) and \ + self.key_ser == other.key_ser and self.val_ser == other.val_ser + + def __str__(self): + return "PairDeserializer<%s, %s>" % \ + (str(self.key_ser), str(self.val_ser)) + + class NoOpSerializer(FramedSerializer): def loads(self, obj): return obj diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py index 8e1cbd4ad9856..080325061a697 100644 --- a/python/pyspark/statcounter.py +++ b/python/pyspark/statcounter.py @@ -26,7 +26,9 @@ def __init__(self, values=[]): self.n = 0L # Running count of our values self.mu = 0.0 # Running mean of our values self.m2 = 0.0 # Running variance numerator (sum of (x - mean)^2) - + self.maxValue = float("-inf") + self.minValue = float("inf") + for v in values: self.merge(v) @@ -36,6 +38,11 @@ def merge(self, value): self.n += 1 self.mu += delta / self.n self.m2 += delta * (value - self.mu) + if self.maxValue < value: + self.maxValue = value + if self.minValue > value: + self.minValue = value + return self # Merge another StatCounter into this one, adding up the internal statistics. @@ -49,7 +56,10 @@ def mergeStats(self, other): if self.n == 0: self.mu = other.mu self.m2 = other.m2 - self.n = other.n + self.n = other.n + self.maxValue = other.maxValue + self.minValue = other.minValue + elif other.n != 0: delta = other.mu - self.mu if other.n * 10 < self.n: @@ -58,6 +68,9 @@ def mergeStats(self, other): self.mu = other.mu - (delta * self.n) / (self.n + other.n) else: self.mu = (self.mu * self.n + other.mu * other.n) / (self.n + other.n) + + self.maxValue = max(self.maxValue, other.maxValue) + self.minValue = min(self.minValue, other.minValue) self.m2 += other.m2 + (delta * delta * self.n * other.n) / (self.n + other.n) self.n += other.n @@ -76,6 +89,12 @@ def mean(self): def sum(self): return self.n * self.mu + def min(self): + return self.minValue + + def max(self): + return self.maxValue + # Return the variance of the values. 
def variance(self): if self.n == 0: @@ -105,5 +124,5 @@ def sampleStdev(self): return math.sqrt(self.sampleVariance()) def __repr__(self): - return "(count: %s, mean: %s, stdev: %s)" % (self.count(), self.mean(), self.stdev()) + return "(count: %s, mean: %s, stdev: %s, max: %s, min: %s)" % (self.count(), self.mean(), self.stdev(), self.max(), self.min()) diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py index b31f4762e69bc..c3e3a44e8e7ab 100644 --- a/python/pyspark/storagelevel.py +++ b/python/pyspark/storagelevel.py @@ -31,6 +31,10 @@ def __init__(self, useDisk, useMemory, deserialized, replication = 1): self.deserialized = deserialized self.replication = replication + def __repr__(self): + return "StorageLevel(%s, %s, %s, %s)" % ( + self.useDisk, self.useMemory, self.deserialized, self.replication) + StorageLevel.DISK_ONLY = StorageLevel(True, False, False) StorageLevel.DISK_ONLY_2 = StorageLevel(True, False, False, 2) StorageLevel.MEMORY_ONLY = StorageLevel(False, True, True) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 158646352039f..4c214ef359685 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -45,34 +45,34 @@ def report_times(outfile, boot, init, finish): def main(infile, outfile): - boot_time = time.time() - split_index = read_int(infile) - if split_index == -1: # for unit tests - return + try: + boot_time = time.time() + split_index = read_int(infile) + if split_index == -1: # for unit tests + return - # fetch name of workdir - spark_files_dir = utf8_deserializer.loads(infile) - SparkFiles._root_directory = spark_files_dir - SparkFiles._is_running_on_worker = True + # fetch name of workdir + spark_files_dir = utf8_deserializer.loads(infile) + SparkFiles._root_directory = spark_files_dir + SparkFiles._is_running_on_worker = True - # fetch names and values of broadcast variables - num_broadcast_variables = read_int(infile) - for _ in range(num_broadcast_variables): - bid = read_long(infile) - value = pickleSer._read_with_length(infile) - _broadcastRegistry[bid] = Broadcast(bid, value) + # fetch names and values of broadcast variables + num_broadcast_variables = read_int(infile) + for _ in range(num_broadcast_variables): + bid = read_long(infile) + value = pickleSer._read_with_length(infile) + _broadcastRegistry[bid] = Broadcast(bid, value) - # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH - sys.path.append(spark_files_dir) # *.py files that were added will be copied here - num_python_includes = read_int(infile) - for _ in range(num_python_includes): - filename = utf8_deserializer.loads(infile) - sys.path.append(os.path.join(spark_files_dir, filename)) + # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH + sys.path.append(spark_files_dir) # *.py files that were added will be copied here + num_python_includes = read_int(infile) + for _ in range(num_python_includes): + filename = utf8_deserializer.loads(infile) + sys.path.append(os.path.join(spark_files_dir, filename)) - command = pickleSer._read_with_length(infile) - (func, deserializer, serializer) = command - init_time = time.time() - try: + command = pickleSer._read_with_length(infile) + (func, deserializer, serializer) = command + init_time = time.time() iterator = deserializer.load_stream(infile) serializer.dump_stream(func(split_index, iterator), outfile) except Exception as e: diff --git a/repl/pom.xml b/repl/pom.xml index 73597f635b9e0..fc49c8b811316 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ 
-21,7 +21,7 @@ org.apache.spark spark-parent - 1.0.0-incubating-SNAPSHOT + 1.0.0-SNAPSHOT ../pom.xml @@ -29,7 +29,21 @@ spark-repl_2.10 jar Spark Project REPL - http://spark.incubator.apache.org/ + http://spark.apache.org/ + + + + + yarn-alpha + + + org.apache.avro + avro + + + + /usr/share/spark @@ -98,7 +112,7 @@ true - + @@ -111,7 +125,7 @@ - + diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index e3bcf7f30ac8d..ee972887feda6 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -18,14 +18,18 @@ package org.apache.spark.repl import java.io.{ByteArrayOutputStream, InputStream} -import java.net.{URI, URL, URLClassLoader, URLEncoder} +import java.net.{URI, URL, URLEncoder} import java.util.concurrent.{Executors, ExecutorService} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} -import org.objectweb.asm._ -import org.objectweb.asm.Opcodes._ +import org.apache.spark.SparkEnv +import org.apache.spark.util.Utils + + +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm._ +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ /** @@ -53,7 +57,13 @@ extends ClassLoader(parent) { if (fileSystem != null) { fileSystem.open(new Path(directory, pathInDirectory)) } else { - new URL(classUri + "/" + urlEncode(pathInDirectory)).openStream() + if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { + val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) + val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager) + newuri.toURL().openStream() + } else { + new URL(classUri + "/" + urlEncode(pathInDirectory)).openStream() + } } } val bytes = readAndTransformClass(name, inputStream) diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 013cea07d48fd..9b1da195002c2 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -182,8 +182,13 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, /** Create a new interpreter. 
*/ def createInterpreter() { - if (addedClasspath != "") - settings.classpath append addedClasspath + require(settings != null) + + if (addedClasspath != "") settings.classpath.append(addedClasspath) + // work around for Scala bug + val totalClassPath = SparkILoop.getAddedJars.foldLeft( + settings.classpath.value)((l, r) => ClassPath.join(l, r)) + this.settings.classpath.value = totalClassPath intp = new SparkILoopInterpreter } @@ -876,6 +881,8 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, }) def process(settings: Settings): Boolean = savingContextLoader { + if (getMaster() == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") + this.settings = settings createInterpreter() @@ -934,16 +941,9 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, def createSparkContext(): SparkContext = { val execUri = System.getenv("SPARK_EXECUTOR_URI") - val master = this.master match { - case Some(m) => m - case None => { - val prop = System.getenv("MASTER") - if (prop != null) prop else "local" - } - } val jars = SparkILoop.getAddedJars.map(new java.io.File(_).getAbsolutePath) val conf = new SparkConf() - .setMaster(master) + .setMaster(getMaster()) .setAppName("Spark shell") .setJars(jars) .set("spark.repl.class.uri", intp.classServer.uri) @@ -958,6 +958,17 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, sparkContext } + private def getMaster(): String = { + val master = this.master match { + case Some(m) => m + case None => { + val prop = System.getenv("MASTER") + if (prop != null) prop else "local" + } + } + master + } + /** process command-line arguments and do as they request */ def process(args: Array[String]): Boolean = { val command = new SparkCommandLine(args.toList, msg => echo(msg)) diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 1d73d0b6993a8..90a96ad38381e 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -36,7 +36,7 @@ import scala.tools.reflect.StdRuntimeTags._ import scala.util.control.ControlThrowable import util.stackTraceString -import org.apache.spark.{HttpServer, SparkConf, Logging} +import org.apache.spark.{Logging, HttpServer, SecurityManager, SparkConf} import org.apache.spark.util.Utils // /** directory to save .class files to */ @@ -83,15 +83,17 @@ import org.apache.spark.util.Utils * @author Moez A. 
Abdel-Gawad * @author Lex Spoon */ - class SparkIMain(initialSettings: Settings, val out: JPrintWriter) extends SparkImports with Logging { + class SparkIMain(initialSettings: Settings, val out: JPrintWriter) + extends SparkImports with Logging { imain => - val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") + val conf = new SparkConf() + val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") /** Local directory to save .class files too */ val outputDir = { val tmp = System.getProperty("java.io.tmpdir") - val rootDir = new SparkConf().get("spark.repl.classdir", tmp) + val rootDir = conf.get("spark.repl.classdir", tmp) Utils.createTempDir(rootDir) } if (SPARK_DEBUG_REPL) { @@ -99,7 +101,8 @@ import org.apache.spark.util.Utils } val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles - val classServer = new HttpServer(outputDir) /** Jetty server that will serve our classes to worker nodes */ + val classServer = new HttpServer(outputDir, + new SecurityManager(conf)) /** Jetty server that will serve our classes to worker nodes */ private var currentSettings: Settings = initialSettings var printResults = true // whether to print result lines var totalSilence = false // whether to print anything diff --git a/sbin/start-all.sh b/sbin/start-all.sh index 2daf49db359df..5c89ab4d86b3a 100755 --- a/sbin/start-all.sh +++ b/sbin/start-all.sh @@ -24,11 +24,22 @@ sbin=`dirname "$0"` sbin=`cd "$sbin"; pwd` +TACHYON_STR="" + +while (( "$#" )); do +case $1 in + --with-tachyon) + TACHYON_STR="--with-tachyon" + ;; + esac +shift +done + # Load the Spark configuration . "$sbin/spark-config.sh" # Start Master -"$sbin"/start-master.sh +"$sbin"/start-master.sh $TACHYON_STR # Start Workers -"$sbin"/start-slaves.sh +"$sbin"/start-slaves.sh $TACHYON_STR diff --git a/sbin/start-master.sh b/sbin/start-master.sh index ec3dfdb4197ec..03a3428aea9f1 100755 --- a/sbin/start-master.sh +++ b/sbin/start-master.sh @@ -22,6 +22,21 @@ sbin=`dirname "$0"` sbin=`cd "$sbin"; pwd` +START_TACHYON=false + +while (( "$#" )); do +case $1 in + --with-tachyon) + if [ ! -e "$sbin"/../tachyon/bin/tachyon ]; then + echo "Error: --with-tachyon specified, but tachyon not found." + exit -1 + fi + START_TACHYON=true + ;; + esac +shift +done + . "$sbin/spark-config.sh" if [ -f "${SPARK_CONF_DIR}/spark-env.sh" ]; then @@ -41,3 +56,9 @@ if [ "$SPARK_MASTER_WEBUI_PORT" = "" ]; then fi "$sbin"/spark-daemon.sh start org.apache.spark.deploy.master.Master 1 --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT + +if [ "$START_TACHYON" == "true" ]; then + "$sbin"/../tachyon/bin/tachyon bootstrap-conf $SPARK_MASTER_IP + "$sbin"/../tachyon/bin/tachyon format -s + "$sbin"/../tachyon/bin/tachyon-start.sh master +fi diff --git a/sbin/start-slaves.sh b/sbin/start-slaves.sh index fd5cdeb1e6788..da641cfe3c6fa 100755 --- a/sbin/start-slaves.sh +++ b/sbin/start-slaves.sh @@ -20,6 +20,22 @@ sbin=`dirname "$0"` sbin=`cd "$sbin"; pwd` + +START_TACHYON=false + +while (( "$#" )); do +case $1 in + --with-tachyon) + if [ ! -e "$sbin"/../tachyon/bin/tachyon ]; then + echo "Error: --with-tachyon specified, but tachyon not found." + exit -1 + fi + START_TACHYON=true + ;; + esac +shift +done + . 
"$sbin/spark-config.sh" if [ -f "${SPARK_CONF_DIR}/spark-env.sh" ]; then @@ -35,6 +51,13 @@ if [ "$SPARK_MASTER_IP" = "" ]; then SPARK_MASTER_IP=`hostname` fi +if [ "$START_TACHYON" == "true" ]; then + "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin"/../tachyon/bin/tachyon bootstrap-conf $SPARK_MASTER_IP + + # set -t so we can call sudo + SPARK_SSH_OPTS="-o StrictHostKeyChecking=no -t" "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/../tachyon/bin/tachyon-start.sh" worker SudoMount \; sleep 1 +fi + # Launch the slaves if [ "$SPARK_WORKER_INSTANCES" = "" ]; then exec "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" 1 spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT diff --git a/sbin/stop-master.sh b/sbin/stop-master.sh index 2adabd426563c..b6bdaa4db373c 100755 --- a/sbin/stop-master.sh +++ b/sbin/stop-master.sh @@ -25,3 +25,7 @@ sbin=`cd "$sbin"; pwd` . "$sbin/spark-config.sh" "$sbin"/spark-daemon.sh stop org.apache.spark.deploy.master.Master 1 + +if [ -e "$sbin"/../tachyon/bin/tachyon ]; then + "$sbin"/../tachyon/bin/tachyon killAll tachyon.master.Master +fi diff --git a/sbin/stop-slaves.sh b/sbin/stop-slaves.sh index eb803b4900347..6bf393ccd4b09 100755 --- a/sbin/stop-slaves.sh +++ b/sbin/stop-slaves.sh @@ -26,6 +26,11 @@ if [ -f "${SPARK_CONF_DIR}/spark-env.sh" ]; then . "${SPARK_CONF_DIR}/spark-env.sh" fi +# do before the below calls as they exec +if [ -e "$sbin"/../tachyon/bin/tachyon ]; then + "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin"/../tachyon/bin/tachyon killAll tachyon.worker.Worker +fi + if [ "$SPARK_WORKER_INSTANCES" = "" ]; then "$sbin"/spark-daemons.sh stop org.apache.spark.deploy.worker.Worker 1 else diff --git a/sbt/sbt b/sbt/sbt index 8472dce589bcc..3ffa4ed9ab5a7 100755 --- a/sbt/sbt +++ b/sbt/sbt @@ -1,51 +1,102 @@ -#!/bin/bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# This script launches sbt for this project. If present it uses the system -# version of sbt. If there is no system version of sbt it attempts to download -# sbt locally. -SBT_VERSION=`awk -F "=" '/sbt\\.version/ {print $2}' ./project/build.properties` -URL1=http://typesafe.artifactoryonline.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar -URL2=http://repo.typesafe.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar -JAR=sbt/sbt-launch-${SBT_VERSION}.jar - -# Download sbt launch jar if it hasn't been downloaded yet -if [ ! 
-f ${JAR} ]; then - # Download - printf "Attempting to fetch sbt\n" - JAR_DL=${JAR}.part - if hash curl 2>/dev/null; then - (curl --progress-bar ${URL1} > ${JAR_DL} || curl --progress-bar ${URL2} > ${JAR_DL}) && mv ${JAR_DL} ${JAR} - elif hash wget 2>/dev/null; then - (wget --progress=bar ${URL1} -O ${JAR_DL} || wget --progress=bar ${URL2} -O ${JAR_DL}) && mv ${JAR_DL} ${JAR} - else - printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n" - exit -1 - fi -fi -if [ ! -f ${JAR} ]; then - # We failed to download - printf "Our attempt to download sbt locally to ${JAR} failed. Please install sbt manually from http://www.scala-sbt.org/\n" - exit -1 -fi -printf "Launching sbt from ${JAR}\n" -java \ - -Xmx1200m -XX:MaxPermSize=350m -XX:ReservedCodeCacheSize=256m \ - -jar ${JAR} \ - "$@" +#!/usr/bin/env bash + +realpath () { +( + TARGET_FILE=$1 + + cd $(dirname $TARGET_FILE) + TARGET_FILE=$(basename $TARGET_FILE) + + COUNT=0 + while [ -L "$TARGET_FILE" -a $COUNT -lt 100 ] + do + TARGET_FILE=$(readlink $TARGET_FILE) + cd $(dirname $TARGET_FILE) + TARGET_FILE=$(basename $TARGET_FILE) + COUNT=$(($COUNT + 1)) + done + + echo $(pwd -P)/$TARGET_FILE +) +} + +. $(dirname $(realpath $0))/sbt-launch-lib.bash + + +declare -r noshare_opts="-Dsbt.global.base=project/.sbtboot -Dsbt.boot.directory=project/.boot -Dsbt.ivy.home=project/.ivy" +declare -r sbt_opts_file=".sbtopts" +declare -r etc_sbt_opts_file="/etc/sbt/sbtopts" + +usage() { + cat < path to global settings/plugins directory (default: ~/.sbt) + -sbt-boot path to shared boot directory (default: ~/.sbt/boot in 0.11 series) + -ivy path to local Ivy repository (default: ~/.ivy2) + -mem set memory options (default: $sbt_mem, which is $(get_mem_opts $sbt_mem)) + -no-share use all local caches; no sharing + -no-global uses global caches, but does not use global ~/.sbt directory. + -jvm-debug Turn on JVM debugging, open at the given port. + -batch Disable interactive mode + + # sbt version (default: from project/build.properties if present, else latest release) + -sbt-version use the specified version of sbt + -sbt-jar use the specified jar as the sbt launcher + -sbt-rc use an RC version of sbt + -sbt-snapshot use a snapshot version of sbt + + # java version (default: java from PATH, currently $(java -version 2>&1 | grep version)) + -java-home alternate JAVA_HOME + + # jvm options and output control + JAVA_OPTS environment variable, if unset uses "$java_opts" + SBT_OPTS environment variable, if unset uses "$default_sbt_opts" + .sbtopts if this file exists in the current directory, it is + prepended to the runner args + /etc/sbt/sbtopts if this file exists, it is prepended to the runner args + -Dkey=val pass -Dkey=val directly to the java runtime + -J-X pass option -X directly to the java runtime + (-J is stripped) + -S-X add -X to sbt's scalacOptions (-J is stripped) + +In the case of duplicated or conflicting options, the order above +shows precedence: JAVA_OPTS lowest, command line options highest. 
+EOM +} + +process_my_args () { + while [[ $# -gt 0 ]]; do + case "$1" in + -no-colors) addJava "-Dsbt.log.noformat=true" && shift ;; + -no-share) addJava "$noshare_opts" && shift ;; + -no-global) addJava "-Dsbt.global.base=$(pwd)/project/.sbtboot" && shift ;; + -sbt-boot) require_arg path "$1" "$2" && addJava "-Dsbt.boot.directory=$2" && shift 2 ;; + -sbt-dir) require_arg path "$1" "$2" && addJava "-Dsbt.global.base=$2" && shift 2 ;; + -debug-inc) addJava "-Dxsbt.inc.debug=true" && shift ;; + -batch) exec &2 "$@" +} +vlog () { + [[ $verbose || $debug ]] && echoerr "$@" +} +dlog () { + [[ $debug ]] && echoerr "$@" +} + +acquire_sbt_jar () { + SBT_VERSION=`awk -F "=" '/sbt\\.version/ {print $2}' ./project/build.properties` + URL1=http://typesafe.artifactoryonline.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar + URL2=http://repo.typesafe.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar + JAR=sbt/sbt-launch-${SBT_VERSION}.jar + + sbt_jar=$JAR + + if [[ ! -f "$sbt_jar" ]]; then + # Download sbt launch jar if it hasn't been downloaded yet + if [ ! -f ${JAR} ]; then + # Download + printf "Attempting to fetch sbt\n" + JAR_DL=${JAR}.part + if hash curl 2>/dev/null; then + (curl --progress-bar ${URL1} > ${JAR_DL} || curl --progress-bar ${URL2} > ${JAR_DL}) && mv ${JAR_DL} ${JAR} + elif hash wget 2>/dev/null; then + (wget --progress=bar ${URL1} -O ${JAR_DL} || wget --progress=bar ${URL2} -O ${JAR_DL}) && mv ${JAR_DL} ${JAR} + else + printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n" + exit -1 + fi + fi + if [ ! -f ${JAR} ]; then + # We failed to download + printf "Our attempt to download sbt locally to ${JAR} failed. Please install sbt manually from http://www.scala-sbt.org/\n" + exit -1 + fi + printf "Launching sbt from ${JAR}\n" + fi +} + +execRunner () { + # print the arguments one to a line, quoting any containing spaces + [[ $verbose || $debug ]] && echo "# Executing command line:" && { + for arg; do + if printf "%s\n" "$arg" | grep -q ' '; then + printf "\"%s\"\n" "$arg" + else + printf "%s\n" "$arg" + fi + done + echo "" + } + + exec "$@" +} + +addJava () { + dlog "[addJava] arg = '$1'" + java_args=( "${java_args[@]}" "$1" ) +} +addSbt () { + dlog "[addSbt] arg = '$1'" + sbt_commands=( "${sbt_commands[@]}" "$1" ) +} +addResidual () { + dlog "[residual] arg = '$1'" + residual_args=( "${residual_args[@]}" "$1" ) +} +addDebugger () { + addJava "-Xdebug -Xrunjdwp:transport=dt_socket,server=y,suspend=n,address=$1" +} + +# a ham-fisted attempt to move some memory settings in concert +# so they need not be dicked around with individually. 
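The get_mem_opts helper that follows derives all four JVM memory flags from the single value passed to -mem, clamping the permgen size to the 256–4096 MB range and giving the code cache half of that. A minimal Java sketch of the same arithmetic (the MemOptsSketch name and its main method are illustrative only, not part of the build scripts):

    // Mirrors the clamping performed by get_mem_opts in sbt-launch-lib.bash.
    public final class MemOptsSketch {
        static String memOpts(int memMb) {
            int perm = Math.min(Math.max(memMb / 4, 256), 4096); // MaxPermSize, clamped
            int codeCache = perm / 2;                            // ReservedCodeCacheSize
            return String.format(
                "-Xms%dm -Xmx%dm -XX:MaxPermSize=%dm -XX:ReservedCodeCacheSize=%dm",
                memMb, memMb, perm, codeCache);
        }

        public static void main(String[] args) {
            // The helper's 2048 MB fallback yields:
            // -Xms2048m -Xmx2048m -XX:MaxPermSize=512m -XX:ReservedCodeCacheSize=256m
            System.out.println(memOpts(2048));
        }
    }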
+get_mem_opts () { + local mem=${1:-2048} + local perm=$(( $mem / 4 )) + (( $perm > 256 )) || perm=256 + (( $perm < 4096 )) || perm=4096 + local codecache=$(( $perm / 2 )) + + echo "-Xms${mem}m -Xmx${mem}m -XX:MaxPermSize=${perm}m -XX:ReservedCodeCacheSize=${codecache}m" +} + +require_arg () { + local type="$1" + local opt="$2" + local arg="$3" + if [[ -z "$arg" ]] || [[ "${arg:0:1}" == "-" ]]; then + die "$opt requires <$type> argument" + fi +} + +is_function_defined() { + declare -f "$1" > /dev/null +} + +process_args () { + while [[ $# -gt 0 ]]; do + case "$1" in + -h|-help) usage; exit 1 ;; + -v|-verbose) verbose=1 && shift ;; + -d|-debug) debug=1 && shift ;; + + -ivy) require_arg path "$1" "$2" && addJava "-Dsbt.ivy.home=$2" && shift 2 ;; + -mem) require_arg integer "$1" "$2" && sbt_mem="$2" && shift 2 ;; + -jvm-debug) require_arg port "$1" "$2" && addDebugger $2 && shift 2 ;; + -batch) exec org.apache.spark spark-parent - 1.0.0-incubating-SNAPSHOT + 1.0.0-SNAPSHOT ../pom.xml @@ -29,21 +29,21 @@ spark-streaming_2.10 jar Spark Project Streaming - http://spark.incubator.apache.org/ + http://spark.apache.org/ - - - apache-repo - Apache Repository - https://repository.apache.org/content/repositories/releases - - true - - - false - - - + + + + yarn-alpha + + + org.apache.avro + avro + + + + @@ -55,11 +55,6 @@ org.eclipse.jetty jetty-server - - org.codehaus.jackson - jackson-mapper-asl - 1.9.11 - org.scala-lang scala-library diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala index e23b725052864..721d50273259e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala @@ -41,7 +41,7 @@ class JavaDStream[T](val dstream: DStream[T])(implicit val classTag: ClassTag[T] /** Return a new DStream containing only the elements that satisfy a predicate. 
*/ def filter(f: JFunction[T, java.lang.Boolean]): JavaDStream[T] = - dstream.filter((x => f(x).booleanValue())) + dstream.filter((x => f.call(x).booleanValue())) /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ def cache(): JavaDStream[T] = dstream.cache() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index 7aa7ead29b469..a85cd04c9319c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -17,19 +17,20 @@ package org.apache.spark.streaming.api.java -import java.util.{List => JList} +import java.util import java.lang.{Long => JLong} +import java.util.{List => JList} import scala.collection.JavaConversions._ import scala.reflect.ClassTag -import org.apache.spark.streaming._ -import org.apache.spark.api.java.{JavaPairRDD, JavaRDDLike, JavaRDD} -import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} -import org.apache.spark.api.java.function.{Function3 => JFunction3, _} -import java.util +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaRDDLike} +import org.apache.spark.api.java.JavaPairRDD._ +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, Function3 => JFunction3, _} import org.apache.spark.rdd.RDD -import JavaDStream._ +import org.apache.spark.streaming._ +import org.apache.spark.streaming.api.java.JavaDStream._ import org.apache.spark.streaming.dstream.DStream trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T, R]] @@ -123,23 +124,23 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * this DStream. Applying glom() to an RDD coalesces all elements within each partition into * an array. */ - def glom(): JavaDStream[JList[T]] = { + def glom(): JavaDStream[JList[T]] = new JavaDStream(dstream.glom().map(x => new java.util.ArrayList[T](x.toSeq))) - } + /** Return the [[org.apache.spark.streaming.StreamingContext]] associated with this DStream */ - def context(): StreamingContext = dstream.context() + def context(): StreamingContext = dstream.context /** Return a new DStream by applying a function to all elements of this DStream. */ def map[R](f: JFunction[T, R]): JavaDStream[R] = { - new JavaDStream(dstream.map(f)(f.returnType()))(f.returnType()) + new JavaDStream(dstream.map(f)(fakeClassTag))(fakeClassTag) } /** Return a new DStream by applying a function to all elements of this DStream. 
*/ - def map[K2, V2](f: PairFunction[T, K2, V2]): JavaPairDStream[K2, V2] = { - def cm = implicitly[ClassTag[Tuple2[_, _]]].asInstanceOf[ClassTag[Tuple2[K2, V2]]] - new JavaPairDStream(dstream.map(f)(cm))(f.keyType(), f.valueType()) + def mapToPair[K2, V2](f: PairFunction[T, K2, V2]): JavaPairDStream[K2, V2] = { + def cm: ClassTag[(K2, V2)] = fakeClassTag + new JavaPairDStream(dstream.map[(K2, V2)](f)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } /** @@ -148,19 +149,19 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T */ def flatMap[U](f: FlatMapFunction[T, U]): JavaDStream[U] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.apply(x).asScala - new JavaDStream(dstream.flatMap(fn)(f.elementType()))(f.elementType()) + def fn = (x: T) => f.call(x).asScala + new JavaDStream(dstream.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U]) } /** * Return a new DStream by applying a function to all elements of this DStream, * and then flattening the results */ - def flatMap[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairDStream[K2, V2] = { + def flatMapToPair[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairDStream[K2, V2] = { import scala.collection.JavaConverters._ - def fn = (x: T) => f.apply(x).asScala - def cm = implicitly[ClassTag[Tuple2[_, _]]].asInstanceOf[ClassTag[Tuple2[K2, V2]]] - new JavaPairDStream(dstream.flatMap(fn)(cm))(f.keyType(), f.valueType()) + def fn = (x: T) => f.call(x).asScala + def cm: ClassTag[(K2, V2)] = fakeClassTag + new JavaPairDStream(dstream.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) } /** @@ -169,8 +170,8 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * of the RDD. */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaDStream[U] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) - new JavaDStream(dstream.mapPartitions(fn)(f.elementType()))(f.elementType()) + def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + new JavaDStream(dstream.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U]) } /** @@ -178,10 +179,10 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition * of the RDD. */ - def mapPartitions[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]) + def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]) : JavaPairDStream[K2, V2] = { - def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) - new JavaPairDStream(dstream.mapPartitions(fn))(f.keyType(), f.valueType()) + def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator()) + new JavaPairDStream(dstream.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2]) } /** @@ -283,8 +284,8 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * on each RDD of 'this' DStream. */ def transform[U](transformFunc: JFunction[R, JavaRDD[U]]): JavaDStream[U] = { - implicit val cm: ClassTag[U] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]] + implicit val cm: ClassTag[U] = fakeClassTag + def scalaTransform (in: RDD[T]): RDD[U] = transformFunc.call(wrapRDD(in)).rdd dstream.transform(scalaTransform(_)) @@ -295,8 +296,8 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * on each RDD of 'this' DStream. 
*/ def transform[U](transformFunc: JFunction2[R, Time, JavaRDD[U]]): JavaDStream[U] = { - implicit val cm: ClassTag[U] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]] + implicit val cm: ClassTag[U] = fakeClassTag + def scalaTransform (in: RDD[T], time: Time): RDD[U] = transformFunc.call(wrapRDD(in), time).rdd dstream.transform(scalaTransform(_, _)) @@ -306,12 +307,11 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * Return a new DStream in which each RDD is generated by applying a function * on each RDD of 'this' DStream. */ - def transform[K2, V2](transformFunc: JFunction[R, JavaPairRDD[K2, V2]]): + def transformToPair[K2, V2](transformFunc: JFunction[R, JavaPairRDD[K2, V2]]): JavaPairDStream[K2, V2] = { - implicit val cmk: ClassTag[K2] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K2]] - implicit val cmv: ClassTag[V2] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V2]] + implicit val cmk: ClassTag[K2] = fakeClassTag + implicit val cmv: ClassTag[V2] = fakeClassTag + def scalaTransform (in: RDD[T]): RDD[(K2, V2)] = transformFunc.call(wrapRDD(in)).rdd dstream.transform(scalaTransform(_)) @@ -321,12 +321,11 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * Return a new DStream in which each RDD is generated by applying a function * on each RDD of 'this' DStream. */ - def transform[K2, V2](transformFunc: JFunction2[R, Time, JavaPairRDD[K2, V2]]): + def transformToPair[K2, V2](transformFunc: JFunction2[R, Time, JavaPairRDD[K2, V2]]): JavaPairDStream[K2, V2] = { - implicit val cmk: ClassTag[K2] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K2]] - implicit val cmv: ClassTag[V2] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V2]] + implicit val cmk: ClassTag[K2] = fakeClassTag + implicit val cmv: ClassTag[V2] = fakeClassTag + def scalaTransform (in: RDD[T], time: Time): RDD[(K2, V2)] = transformFunc.call(wrapRDD(in), time).rdd dstream.transform(scalaTransform(_, _)) @@ -340,10 +339,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T other: JavaDStream[U], transformFunc: JFunction3[R, JavaRDD[U], Time, JavaRDD[W]] ): JavaDStream[W] = { - implicit val cmu: ClassTag[U] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]] - implicit val cmv: ClassTag[W] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]] + implicit val cmu: ClassTag[U] = fakeClassTag + implicit val cmv: ClassTag[W] = fakeClassTag + def scalaTransform (inThis: RDD[T], inThat: RDD[U], time: Time): RDD[W] = transformFunc.call(wrapRDD(inThis), other.wrapRDD(inThat), time).rdd dstream.transformWith[U, W](other.dstream, scalaTransform(_, _, _)) @@ -353,16 +351,13 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * Return a new DStream in which each RDD is generated by applying a function * on each RDD of 'this' DStream and 'other' DStream. 
*/ - def transformWith[U, K2, V2]( + def transformWithToPair[U, K2, V2]( other: JavaDStream[U], transformFunc: JFunction3[R, JavaRDD[U], Time, JavaPairRDD[K2, V2]] ): JavaPairDStream[K2, V2] = { - implicit val cmu: ClassTag[U] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]] - implicit val cmk2: ClassTag[K2] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K2]] - implicit val cmv2: ClassTag[V2] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V2]] + implicit val cmu: ClassTag[U] = fakeClassTag + implicit val cmk2: ClassTag[K2] = fakeClassTag + implicit val cmv2: ClassTag[V2] = fakeClassTag def scalaTransform (inThis: RDD[T], inThat: RDD[U], time: Time): RDD[(K2, V2)] = transformFunc.call(wrapRDD(inThis), other.wrapRDD(inThat), time).rdd dstream.transformWith[U, (K2, V2)](other.dstream, scalaTransform(_, _, _)) @@ -376,12 +371,10 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T other: JavaPairDStream[K2, V2], transformFunc: JFunction3[R, JavaPairRDD[K2, V2], Time, JavaRDD[W]] ): JavaDStream[W] = { - implicit val cmk2: ClassTag[K2] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K2]] - implicit val cmv2: ClassTag[V2] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V2]] - implicit val cmw: ClassTag[W] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]] + implicit val cmk2: ClassTag[K2] = fakeClassTag + implicit val cmv2: ClassTag[V2] = fakeClassTag + implicit val cmw: ClassTag[W] = fakeClassTag + def scalaTransform (inThis: RDD[T], inThat: RDD[(K2, V2)], time: Time): RDD[W] = transformFunc.call(wrapRDD(inThis), other.wrapRDD(inThat), time).rdd dstream.transformWith[(K2, V2), W](other.dstream, scalaTransform(_, _, _)) @@ -391,18 +384,14 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * Return a new DStream in which each RDD is generated by applying a function * on each RDD of 'this' DStream and 'other' DStream. 
*/ - def transformWith[K2, V2, K3, V3]( + def transformWithToPair[K2, V2, K3, V3]( other: JavaPairDStream[K2, V2], transformFunc: JFunction3[R, JavaPairRDD[K2, V2], Time, JavaPairRDD[K3, V3]] ): JavaPairDStream[K3, V3] = { - implicit val cmk2: ClassTag[K2] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K2]] - implicit val cmv2: ClassTag[V2] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V2]] - implicit val cmk3: ClassTag[K3] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K3]] - implicit val cmv3: ClassTag[V3] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V3]] + implicit val cmk2: ClassTag[K2] = fakeClassTag + implicit val cmv2: ClassTag[V2] = fakeClassTag + implicit val cmk3: ClassTag[K3] = fakeClassTag + implicit val cmv3: ClassTag[V3] = fakeClassTag def scalaTransform (inThis: RDD[T], inThat: RDD[(K2, V2)], time: Time): RDD[(K3, V3)] = transformFunc.call(wrapRDD(inThis), other.wrapRDD(inThat), time).rdd dstream.transformWith[(K2, V2), (K3, V3)](other.dstream, scalaTransform(_, _, _)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index 4dcd0e4c51ec3..ac451d1913aaa 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -17,24 +17,25 @@ package org.apache.spark.streaming.api.java -import java.util.{List => JList} import java.lang.{Long => JLong} +import java.util.{List => JList} import scala.collection.JavaConversions._ import scala.reflect.ClassTag -import org.apache.spark.streaming._ -import org.apache.spark.streaming.StreamingContext._ -import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, Function3 => JFunction3} -import org.apache.spark.Partitioner +import com.google.common.base.Optional +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} -import org.apache.hadoop.conf.Configuration -import org.apache.spark.api.java.{JavaUtils, JavaRDD, JavaPairRDD} -import org.apache.spark.storage.StorageLevel -import com.google.common.base.Optional +import org.apache.spark.Partitioner +import org.apache.spark.api.java.{JavaPairRDD, JavaUtils} +import org.apache.spark.api.java.JavaPairRDD._ +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.PairRDDFunctions +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming._ +import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.dstream.DStream /** @@ -54,7 +55,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( /** Return a new DStream containing only the elements that satisfy a predicate. */ def filter(f: JFunction[(K, V), java.lang.Boolean]): JavaPairDStream[K, V] = - dstream.filter((x => f(x).booleanValue())) + dstream.filter((x => f.call(x).booleanValue())) /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ def cache(): JavaPairDStream[K, V] = dstream.cache() @@ -127,7 +128,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( /** * Return a new DStream by applying `groupByKey` on each RDD of `this` DStream. 
* Therefore, the values for each key in `this` DStream's RDDs are grouped into a - * single sequence to generate the RDDs of the new DStream. [[org.apache.spark.Partitioner]] + * single sequence to generate the RDDs of the new DStream. org.apache.spark.Partitioner * is used to control the partitioning of each RDD. */ def groupByKey(partitioner: Partitioner): JavaPairDStream[K, JList[V]] = @@ -151,7 +152,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( /** * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are - * merged using the supplied reduce function. [[org.apache.spark.Partitioner]] is used to control + * merged using the supplied reduce function. org.apache.spark.Partitioner is used to control * thepartitioning of each RDD. */ def reduceByKey(func: JFunction2[V, V, V], partitioner: Partitioner): JavaPairDStream[K, V] = { @@ -161,22 +162,21 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( /** * Combine elements of each key in DStream's RDDs using custom function. This is similar to the * combineByKey for RDDs. Please refer to combineByKey in - * [[org.apache.spark.rdd.PairRDDFunctions]] for more information. + * org.apache.spark.rdd.PairRDDFunctions for more information. */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], mergeCombiners: JFunction2[C, C, C], partitioner: Partitioner ): JavaPairDStream[K, C] = { - implicit val cm: ClassTag[C] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[C]] + implicit val cm: ClassTag[C] = fakeClassTag dstream.combineByKey(createCombiner, mergeValue, mergeCombiners, partitioner) } /** * Combine elements of each key in DStream's RDDs using custom function. This is similar to the * combineByKey for RDDs. Please refer to combineByKey in - * [[org.apache.spark.rdd.PairRDDFunctions]] for more information. + * org.apache.spark.rdd.PairRDDFunctions for more information. */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], @@ -184,8 +184,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( partitioner: Partitioner, mapSideCombine: Boolean ): JavaPairDStream[K, C] = { - implicit val cm: ClassTag[C] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[C]] + implicit val cm: ClassTag[C] = fakeClassTag dstream.combineByKey(createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine) } @@ -279,7 +278,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * DStream's batching interval */ def reduceByKeyAndWindow( - reduceFunc: Function2[V, V, V], + reduceFunc: JFunction2[V, V, V], windowDuration: Duration, slideDuration: Duration ):JavaPairDStream[K, V] = { @@ -299,7 +298,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * @param numPartitions Number of partitions of each RDD in the new DStream. */ def reduceByKeyAndWindow( - reduceFunc: Function2[V, V, V], + reduceFunc: JFunction2[V, V, V], windowDuration: Duration, slideDuration: Duration, numPartitions: Int @@ -320,7 +319,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * DStream. 
*/ def reduceByKeyAndWindow( - reduceFunc: Function2[V, V, V], + reduceFunc: JFunction2[V, V, V], windowDuration: Duration, slideDuration: Duration, partitioner: Partitioner @@ -345,8 +344,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * DStream's batching interval */ def reduceByKeyAndWindow( - reduceFunc: Function2[V, V, V], - invReduceFunc: Function2[V, V, V], + reduceFunc: JFunction2[V, V, V], + invReduceFunc: JFunction2[V, V, V], windowDuration: Duration, slideDuration: Duration ): JavaPairDStream[K, V] = { @@ -374,8 +373,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * set this to null if you do not want to filter */ def reduceByKeyAndWindow( - reduceFunc: Function2[V, V, V], - invReduceFunc: Function2[V, V, V], + reduceFunc: JFunction2[V, V, V], + invReduceFunc: JFunction2[V, V, V], windowDuration: Duration, slideDuration: Duration, numPartitions: Int, @@ -412,8 +411,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * set this to null if you do not want to filter */ def reduceByKeyAndWindow( - reduceFunc: Function2[V, V, V], - invReduceFunc: Function2[V, V, V], + reduceFunc: JFunction2[V, V, V], + invReduceFunc: JFunction2[V, V, V], windowDuration: Duration, slideDuration: Duration, partitioner: Partitioner, @@ -453,8 +452,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( */ def updateStateByKey[S](updateFunc: JFunction2[JList[V], Optional[S], Optional[S]]) : JavaPairDStream[K, S] = { - implicit val cm: ClassTag[S] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[S]] + implicit val cm: ClassTag[S] = fakeClassTag dstream.updateStateByKey(convertUpdateStateFunction(updateFunc)) } @@ -471,15 +469,14 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( updateFunc: JFunction2[JList[V], Optional[S], Optional[S]], numPartitions: Int) : JavaPairDStream[K, S] = { - implicit val cm: ClassTag[S] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[S]] + implicit val cm: ClassTag[S] = fakeClassTag dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), numPartitions) } /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of the key. - * [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD. + * org.apache.spark.Partitioner is used to control the partitioning of each RDD. * @param updateFunc State update function. If `this` function returns None, then * corresponding state key-value pair will be eliminated. * @param partitioner Partitioner for controlling the partitioning of each RDD in the new @@ -490,8 +487,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( updateFunc: JFunction2[JList[V], Optional[S], Optional[S]], partitioner: Partitioner ): JavaPairDStream[K, S] = { - implicit val cm: ClassTag[S] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[S]] + implicit val cm: ClassTag[S] = fakeClassTag dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), partitioner) } @@ -501,8 +497,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * 'this' DStream without changing the key. */ def mapValues[U](f: JFunction[V, U]): JavaPairDStream[K, U] = { - implicit val cm: ClassTag[U] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]] + implicit val cm: ClassTag[U] = fakeClassTag dstream.mapValues(f) } @@ -524,8 +519,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * of partitions. 
*/ def cogroup[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (JList[V], JList[W])] = { - implicit val cm: ClassTag[W] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]] + implicit val cm: ClassTag[W] = fakeClassTag dstream.cogroup(other.dstream).mapValues(t => (seqAsJavaList(t._1), seqAsJavaList((t._2)))) } @@ -537,8 +531,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( other: JavaPairDStream[K, W], numPartitions: Int ): JavaPairDStream[K, (JList[V], JList[W])] = { - implicit val cm: ClassTag[W] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]] + implicit val cm: ClassTag[W] = fakeClassTag dstream.cogroup(other.dstream, numPartitions) .mapValues(t => (seqAsJavaList(t._1), seqAsJavaList((t._2)))) } @@ -551,8 +544,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( other: JavaPairDStream[K, W], partitioner: Partitioner ): JavaPairDStream[K, (JList[V], JList[W])] = { - implicit val cm: ClassTag[W] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]] + implicit val cm: ClassTag[W] = fakeClassTag dstream.cogroup(other.dstream, partitioner) .mapValues(t => (seqAsJavaList(t._1), seqAsJavaList((t._2)))) } @@ -562,8 +554,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. */ def join[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (V, W)] = { - implicit val cm: ClassTag[W] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]] + implicit val cm: ClassTag[W] = fakeClassTag dstream.join(other.dstream) } @@ -572,21 +563,19 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. */ def join[W](other: JavaPairDStream[K, W], numPartitions: Int): JavaPairDStream[K, (V, W)] = { - implicit val cm: ClassTag[W] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]] + implicit val cm: ClassTag[W] = fakeClassTag dstream.join(other.dstream, numPartitions) } /** * Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream. - * The supplied [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD. + * The supplied org.apache.spark.Partitioner is used to control the partitioning of each RDD. */ def join[W]( other: JavaPairDStream[K, W], partitioner: Partitioner ): JavaPairDStream[K, (V, W)] = { - implicit val cm: ClassTag[W] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]] + implicit val cm: ClassTag[W] = fakeClassTag dstream.join(other.dstream, partitioner) } @@ -596,8 +585,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * number of partitions. 
*/ def leftOuterJoin[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (V, Optional[W])] = { - implicit val cm: ClassTag[W] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]] + implicit val cm: ClassTag[W] = fakeClassTag val joinResult = dstream.leftOuterJoin(other.dstream) joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))} } @@ -611,22 +599,20 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( other: JavaPairDStream[K, W], numPartitions: Int ): JavaPairDStream[K, (V, Optional[W])] = { - implicit val cm: ClassTag[W] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]] + implicit val cm: ClassTag[W] = fakeClassTag val joinResult = dstream.leftOuterJoin(other.dstream, numPartitions) joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))} } /** * Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream. - * The supplied [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD. + * The supplied org.apache.spark.Partitioner is used to control the partitioning of each RDD. */ def leftOuterJoin[W]( other: JavaPairDStream[K, W], partitioner: Partitioner ): JavaPairDStream[K, (V, Optional[W])] = { - implicit val cm: ClassTag[W] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]] + implicit val cm: ClassTag[W] = fakeClassTag val joinResult = dstream.leftOuterJoin(other.dstream, partitioner) joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))} } @@ -652,23 +638,21 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( other: JavaPairDStream[K, W], numPartitions: Int ): JavaPairDStream[K, (Optional[V], W)] = { - implicit val cm: ClassTag[W] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]] + implicit val cm: ClassTag[W] = fakeClassTag val joinResult = dstream.rightOuterJoin(other.dstream, numPartitions) joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)} } /** * Return a new DStream by applying 'right outer join' between RDDs of `this` DStream and - * `other` DStream. The supplied [[org.apache.spark.Partitioner]] is used to control + * `other` DStream. The supplied org.apache.spark.Partitioner is used to control * the partitioning of each RDD. 
*/ def rightOuterJoin[W]( other: JavaPairDStream[K, W], partitioner: Partitioner ): JavaPairDStream[K, (Optional[V], W)] = { - implicit val cm: ClassTag[W] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]] + implicit val cm: ClassTag[W] = fakeClassTag val joinResult = dstream.rightOuterJoin(other.dstream, partitioner) joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)} } @@ -748,8 +732,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( new JavaDStream[(K, V)](dstream) } - override val classTag: ClassTag[(K, V)] = - implicitly[ClassTag[Tuple2[_, _]]].asInstanceOf[ClassTag[Tuple2[K, V]]] + override val classTag: ClassTag[(K, V)] = fakeClassTag } object JavaPairDStream { @@ -758,10 +741,8 @@ object JavaPairDStream { } def fromJavaDStream[K, V](dstream: JavaDStream[(K, V)]): JavaPairDStream[K, V] = { - implicit val cmk: ClassTag[K] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] - implicit val cmv: ClassTag[V] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] + implicit val cmk: ClassTag[K] = fakeClassTag + implicit val cmv: ClassTag[V] = fakeClassTag new JavaPairDStream[K, V](dstream.dstream) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 2268160dccc1f..c48d754e439e9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -187,7 +187,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { converter: JFunction[InputStream, java.lang.Iterable[T]], storageLevel: StorageLevel) : JavaDStream[T] = { - def fn = (x: InputStream) => converter.apply(x).toIterator + def fn = (x: InputStream) => converter.call(x).toIterator implicit val cmt: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] ssc.socketStream(hostname, port, fn, storageLevel) @@ -406,7 +406,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { * JavaPairDStream in the list of JavaDStreams, convert it to a JavaDStream using * [[org.apache.spark.streaming.api.java.JavaPairDStream]].toJavaDStream(). * In the transform function, convert the JavaRDD corresponding to that JavaDStream to - * a JavaPairRDD using [[org.apache.spark.api.java.JavaPairRDD]].fromJavaRDD(). + * a JavaPairRDD using org.apache.spark.api.java.JavaPairRDD.fromJavaRDD(). */ def transform[T]( dstreams: JList[JavaDStream[_]], @@ -429,9 +429,9 @@ class JavaStreamingContext(val ssc: StreamingContext) { * JavaPairDStream in the list of JavaDStreams, convert it to a JavaDStream using * [[org.apache.spark.streaming.api.java.JavaPairDStream]].toJavaDStream(). * In the transform function, convert the JavaRDD corresponding to that JavaDStream to - * a JavaPairRDD using [[org.apache.spark.api.java.JavaPairRDD]].fromJavaRDD(). + * a JavaPairRDD using org.apache.spark.api.java.JavaPairRDD.fromJavaRDD(). 
*/ - def transform[K, V]( + def transformToPair[K, V]( dstreams: JList[JavaDStream[_]], transformFunc: JFunction2[JList[JavaRDD[_]], Time, JavaPairRDD[K, V]] ): JavaPairDStream[K, V] = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index f3c58aede092a..2473496949360 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -65,7 +65,7 @@ class PairDStreamFunctions[K: ClassTag, V: ClassTag](self: DStream[(K,V)]) /** * Return a new DStream by applying `groupByKey` on each RDD. The supplied - * [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD. + * org.apache.spark.Partitioner is used to control the partitioning of each RDD. */ def groupByKey(partitioner: Partitioner): DStream[(K, Seq[V])] = { val createCombiner = (v: V) => ArrayBuffer[V](v) @@ -95,7 +95,7 @@ class PairDStreamFunctions[K: ClassTag, V: ClassTag](self: DStream[(K,V)]) /** * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are - * merged using the supplied reduce function. [[org.apache.spark.Partitioner]] is used to control + * merged using the supplied reduce function. org.apache.spark.Partitioner is used to control * the partitioning of each RDD. */ def reduceByKey(reduceFunc: (V, V) => V, partitioner: Partitioner): DStream[(K, V)] = { @@ -376,7 +376,7 @@ class PairDStreamFunctions[K: ClassTag, V: ClassTag](self: DStream[(K,V)]) /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of the key. - * [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD. + * org.apache.spark.Partitioner is used to control the partitioning of each RDD. * @param updateFunc State update function. If `this` function returns None, then * corresponding state key-value pair will be eliminated. * @param partitioner Partitioner for controlling the partitioning of each RDD in the new @@ -396,7 +396,7 @@ class PairDStreamFunctions[K: ClassTag, V: ClassTag](self: DStream[(K,V)]) /** * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. - * [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD. + * org.apache.spark.Partitioner is used to control the partitioning of each RDD. * @param updateFunc State update function. If `this` function returns None, then * corresponding state key-value pair will be eliminated. Note, that * this function may generate a different a tuple with a different key @@ -453,7 +453,7 @@ class PairDStreamFunctions[K: ClassTag, V: ClassTag](self: DStream[(K,V)]) /** * Return a new DStream by applying 'cogroup' between RDDs of `this` DStream and `other` DStream. - * The supplied [[org.apache.spark.Partitioner]] is used to partition the generated RDDs. + * The supplied org.apache.spark.Partitioner is used to partition the generated RDDs. */ def cogroup[W: ClassTag]( other: DStream[(K, W)], @@ -483,7 +483,7 @@ class PairDStreamFunctions[K: ClassTag, V: ClassTag](self: DStream[(K,V)]) /** * Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream. 
- * The supplied [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD. + * The supplied org.apache.spark.Partitioner is used to control the partitioning of each RDD. */ def join[W: ClassTag]( other: DStream[(K, W)], @@ -518,7 +518,7 @@ class PairDStreamFunctions[K: ClassTag, V: ClassTag](self: DStream[(K,V)]) /** * Return a new DStream by applying 'left outer join' between RDDs of `this` DStream and - * `other` DStream. The supplied [[org.apache.spark.Partitioner]] is used to control + * `other` DStream. The supplied org.apache.spark.Partitioner is used to control * the partitioning of each RDD. */ def leftOuterJoin[W: ClassTag]( @@ -554,7 +554,7 @@ class PairDStreamFunctions[K: ClassTag, V: ClassTag](self: DStream[(K,V)]) /** * Return a new DStream by applying 'right outer join' between RDDs of `this` DStream and - * `other` DStream. The supplied [[org.apache.spark.Partitioner]] is used to control + * `other` DStream. The supplied org.apache.spark.Partitioner is used to control * the partitioning of each RDD. */ def rightOuterJoin[W: ClassTag]( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RateLimitedOutputStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RateLimitedOutputStream.scala index 2b8cdb72b8d0e..a96e2924a0b44 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RateLimitedOutputStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RateLimitedOutputStream.scala @@ -22,12 +22,20 @@ import scala.annotation.tailrec import java.io.OutputStream import java.util.concurrent.TimeUnit._ +import org.apache.spark.Logging + + private[streaming] -class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream { - val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS) - val CHUNK_SIZE = 8192 - var lastSyncTime = System.nanoTime - var bytesWrittenSinceSync: Long = 0 +class RateLimitedOutputStream(out: OutputStream, desiredBytesPerSec: Int) + extends OutputStream + with Logging { + + require(desiredBytesPerSec > 0) + + private val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS) + private val CHUNK_SIZE = 8192 + private var lastSyncTime = System.nanoTime + private var bytesWrittenSinceSync = 0L override def write(b: Int) { waitToWrite(1) @@ -59,9 +67,9 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu @tailrec private def waitToWrite(numBytes: Int) { val now = System.nanoTime - val elapsedSecs = SECONDS.convert(math.max(now - lastSyncTime, 1), NANOSECONDS) - val rate = bytesWrittenSinceSync.toDouble / elapsedSecs - if (rate < bytesPerSec) { + val elapsedNanosecs = math.max(now - lastSyncTime, 1) + val rate = bytesWrittenSinceSync.toDouble * 1000000000 / elapsedNanosecs + if (rate < desiredBytesPerSec) { // It's okay to write; just update some variables and return bytesWrittenSinceSync += numBytes if (now > lastSyncTime + SYNC_INTERVAL) { @@ -71,13 +79,14 @@ class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends Outpu } } else { // Calculate how much time we should sleep to bring ourselves to the desired rate. 
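The replacement arithmetic below works in nanoseconds and milliseconds instead of rounding the elapsed time down to whole seconds, so short bursts are throttled accurately. A standalone Java sketch of the same calculation, with illustrative names only:

    // Sleep until bytesWrittenSinceSync / desiredBytesPerSec of wall time has passed.
    final class ThrottleSketch {
        /** Milliseconds to sleep; zero or negative means the stream may keep writing. */
        static long sleepTimeMillis(long bytesWrittenSinceSync, long elapsedNanos,
                                    int desiredBytesPerSec) {
            long targetTimeInMillis = bytesWrittenSinceSync * 1000 / desiredBytesPerSec;
            long elapsedTimeInMillis = elapsedNanos / 1000000;
            return targetTimeInMillis - elapsedTimeInMillis;
        }

        public static void main(String[] args) {
            // 41000 bytes at 10000 bytes/sec should take about 4100 ms in total, which is
            // what the updated RateLimitedOutputStreamSuite asserts (between 4 and 5 seconds).
            System.out.println(sleepTimeMillis(41000, 100000000L, 10000)); // prints 4000
        }
    }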
- // Based on throttler in Kafka - // scalastyle:off - // (https://github.com/kafka-dev/kafka/blob/master/core/src/main/scala/kafka/utils/Throttler.scala) - // scalastyle:on - val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), - SECONDS) - if (sleepTime > 0) Thread.sleep(sleepTime) + val targetTimeInMillis = bytesWrittenSinceSync * 1000 / desiredBytesPerSec + val elapsedTimeInMillis = elapsedNanosecs / 1000000 + val sleepTimeInMillis = targetTimeInMillis - elapsedTimeInMillis + if (sleepTimeInMillis > 0) { + logTrace("Natural rate is " + rate + " per second but desired rate is " + + desiredBytesPerSec + ", sleeping for " + sleepTimeInMillis + " ms to compensate.") + Thread.sleep(sleepTimeInMillis) + } waitToWrite(numBytes) } } diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 4fbbce9b8b90e..e93bf18b6d0b9 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -19,7 +19,6 @@ import scala.Tuple2; -import org.junit.After; import org.junit.Assert; import org.junit.Test; import java.io.*; @@ -30,7 +29,6 @@ import com.google.common.io.Files; import com.google.common.collect.Sets; -import org.apache.spark.SparkConf; import org.apache.spark.HashPartitioner; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; @@ -38,6 +36,7 @@ import org.apache.spark.api.java.function.*; import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaDStreamLike; import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; @@ -45,6 +44,8 @@ // serialized, as an alternative to converting these anonymous classes to static inner classes; // see http://stackoverflow.com/questions/758570/. 
public class JavaAPISuite extends LocalJavaStreamingContext implements Serializable { + + @SuppressWarnings("unchecked") @Test public void testCount() { List> inputData = Arrays.asList( @@ -64,6 +65,7 @@ public void testCount() { assertOrderInvariantEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testMap() { List> inputData = Arrays.asList( @@ -87,6 +89,7 @@ public Integer call(String s) throws Exception { assertOrderInvariantEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testWindow() { List> inputData = Arrays.asList( @@ -108,6 +111,7 @@ public void testWindow() { assertOrderInvariantEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testWindowWithSlideDuration() { List> inputData = Arrays.asList( @@ -132,6 +136,7 @@ public void testWindowWithSlideDuration() { assertOrderInvariantEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testFilter() { List> inputData = Arrays.asList( @@ -155,13 +160,16 @@ public Boolean call(String s) throws Exception { assertOrderInvariantEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testRepartitionMorePartitions() { List> inputData = Arrays.asList( Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)); - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 2); - JavaDStream repartitioned = stream.repartition(4); + JavaDStream stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 2); + JavaDStreamLike,JavaRDD> repartitioned = + stream.repartition(4); JavaTestUtils.attachTestOutputStream(repartitioned); List>> result = JavaTestUtils.runStreamsWithPartitions(ssc, 2, 2); Assert.assertEquals(2, result.size()); @@ -172,13 +180,16 @@ public void testRepartitionMorePartitions() { } } + @SuppressWarnings("unchecked") @Test public void testRepartitionFewerPartitions() { List> inputData = Arrays.asList( Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)); - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 4); - JavaDStream repartitioned = stream.repartition(2); + JavaDStream stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 4); + JavaDStreamLike,JavaRDD> repartitioned = + stream.repartition(2); JavaTestUtils.attachTestOutputStream(repartitioned); List>> result = JavaTestUtils.runStreamsWithPartitions(ssc, 2, 2); Assert.assertEquals(2, result.size()); @@ -188,6 +199,7 @@ public void testRepartitionFewerPartitions() { } } + @SuppressWarnings("unchecked") @Test public void testGlom() { List> inputData = Arrays.asList( @@ -206,6 +218,7 @@ public void testGlom() { Assert.assertEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testMapPartitions() { List> inputData = Arrays.asList( @@ -217,36 +230,38 @@ public void testMapPartitions() { Arrays.asList("YANKEESRED SOCKS")); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream mapped = stream.mapPartitions(new FlatMapFunction, String>() { - @Override - public Iterable call(Iterator in) { - String out = ""; - while (in.hasNext()) { - out = out + in.next().toUpperCase(); - } - return Lists.newArrayList(out); - } - }); + JavaDStream mapped = stream.mapPartitions( + new FlatMapFunction, String>() { + @Override + public Iterable call(Iterator in) { + String out = ""; + while (in.hasNext()) { + out = out + in.next().toUpperCase(); + } + return Lists.newArrayList(out); + } + }); 
JavaTestUtils.attachTestOutputStream(mapped); List> result = JavaTestUtils.runStreams(ssc, 2, 2); Assert.assertEquals(expected, result); } - private class IntegerSum extends Function2 { + private class IntegerSum implements Function2 { @Override public Integer call(Integer i1, Integer i2) throws Exception { return i1 + i2; } } - private class IntegerDifference extends Function2 { + private class IntegerDifference implements Function2 { @Override public Integer call(Integer i1, Integer i2) throws Exception { return i1 - i2; } } + @SuppressWarnings("unchecked") @Test public void testReduce() { List> inputData = Arrays.asList( @@ -267,6 +282,7 @@ public void testReduce() { Assert.assertEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testReduceByWindow() { List> inputData = Arrays.asList( @@ -289,6 +305,7 @@ public void testReduceByWindow() { Assert.assertEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testQueueStream() { List> expected = Arrays.asList( @@ -312,6 +329,7 @@ public void testQueueStream() { Assert.assertEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testTransform() { List> inputData = Arrays.asList( @@ -344,6 +362,7 @@ public Integer call(Integer i) throws Exception { assertOrderInvariantEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testVariousTransform() { // tests whether all variations of transform can be called from Java @@ -373,7 +392,7 @@ public JavaRDD call(JavaRDD in) throws Exception { } ); - JavaPairDStream transformed3 = stream.transform( + JavaPairDStream transformed3 = stream.transformToPair( new Function, JavaPairRDD>() { @Override public JavaPairRDD call(JavaRDD in) throws Exception { return null; @@ -381,7 +400,7 @@ public JavaRDD call(JavaRDD in) throws Exception { } ); - JavaPairDStream transformed4 = stream.transform( + JavaPairDStream transformed4 = stream.transformToPair( new Function2, Time, JavaPairRDD>() { @Override public JavaPairRDD call(JavaRDD in, Time time) throws Exception { return null; @@ -405,7 +424,7 @@ public JavaRDD call(JavaRDD in) throws Exception { } ); - JavaPairDStream pairTransformed3 = pairStream.transform( + JavaPairDStream pairTransformed3 = pairStream.transformToPair( new Function, JavaPairRDD>() { @Override public JavaPairRDD call(JavaPairRDD in) throws Exception { return null; @@ -413,7 +432,7 @@ public JavaRDD call(JavaRDD in) throws Exception { } ); - JavaPairDStream pairTransformed4 = pairStream.transform( + JavaPairDStream pairTransformed4 = pairStream.transformToPair( new Function2, Time, JavaPairRDD>() { @Override public JavaPairRDD call(JavaPairRDD in, Time time) throws Exception { return null; @@ -423,6 +442,7 @@ public JavaRDD call(JavaRDD in) throws Exception { } + @SuppressWarnings("unchecked") @Test public void testTransformWith() { List>> stringStringKVStream1 = Arrays.asList( @@ -462,7 +482,7 @@ public void testTransformWith() { ssc, stringStringKVStream2, 1); JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); - JavaPairDStream> joined = pairStream1.transformWith( + JavaPairDStream> joined = pairStream1.transformWithToPair( pairStream2, new Function3< JavaPairRDD, @@ -492,6 +512,7 @@ public JavaPairRDD> call( } + @SuppressWarnings("unchecked") @Test public void testVariousTransformWith() { // tests whether all variations of transformWith can be called from Java @@ -530,7 +551,7 @@ public JavaRDD call(JavaRDD rdd1, JavaPairRDD } ); - JavaPairDStream transformed3 = 
stream1.transformWith( + JavaPairDStream transformed3 = stream1.transformWithToPair( stream2, new Function3, JavaRDD, Time, JavaPairRDD>() { @Override @@ -540,7 +561,7 @@ public JavaPairRDD call(JavaRDD rdd1, JavaRDD r } ); - JavaPairDStream transformed4 = stream1.transformWith( + JavaPairDStream transformed4 = stream1.transformWithToPair( pairStream1, new Function3, JavaPairRDD, Time, JavaPairRDD>() { @Override @@ -570,7 +591,7 @@ public JavaRDD call(JavaPairRDD rdd1, JavaPairRDD pairTransformed3 = pairStream1.transformWith( + JavaPairDStream pairTransformed3 = pairStream1.transformWithToPair( stream2, new Function3, JavaRDD, Time, JavaPairRDD>() { @Override @@ -580,7 +601,7 @@ public JavaPairRDD call(JavaPairRDD rdd1, JavaR } ); - JavaPairDStream pairTransformed4 = pairStream1.transformWith( + JavaPairDStream pairTransformed4 = pairStream1.transformWithToPair( pairStream2, new Function3, JavaPairRDD, Time, JavaPairRDD>() { @Override @@ -591,6 +612,7 @@ public JavaPairRDD call(JavaPairRDD rdd1, JavaP ); } + @SuppressWarnings("unchecked") @Test public void testStreamingContextTransform(){ List> stream1input = Arrays.asList( @@ -634,7 +656,7 @@ public JavaRDD call(List> listOfRDDs, Time time) { List> listOfDStreams2 = Arrays.>asList(stream1, stream2, pairStream1.toJavaDStream()); - JavaPairDStream> transformed2 = ssc.transform( + JavaPairDStream> transformed2 = ssc.transformToPair( listOfDStreams2, new Function2>, Time, JavaPairRDD>>() { public JavaPairRDD> call(List> listOfRDDs, Time time) { @@ -649,7 +671,7 @@ public Tuple2 call(Integer i) throws Exception { return new Tuple2(i, i); } }; - return rdd1.union(rdd2).map(mapToTuple).join(prdd3); + return rdd1.union(rdd2).mapToPair(mapToTuple).join(prdd3); } } ); @@ -658,6 +680,7 @@ public Tuple2 call(Integer i) throws Exception { Assert.assertEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testFlatMap() { List> inputData = Arrays.asList( @@ -683,6 +706,7 @@ public Iterable call(String x) { assertOrderInvariantEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testPairFlatMap() { List> inputData = Arrays.asList( @@ -718,22 +742,24 @@ public void testPairFlatMap() { new Tuple2(9, "s"))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream flatMapped = stream.flatMap(new PairFlatMapFunction() { - @Override - public Iterable> call(String in) throws Exception { - List> out = Lists.newArrayList(); - for (String letter: in.split("(?!^)")) { - out.add(new Tuple2(in.length(), letter)); + JavaPairDStream flatMapped = stream.flatMapToPair( + new PairFlatMapFunction() { + @Override + public Iterable> call(String in) throws Exception { + List> out = Lists.newArrayList(); + for (String letter: in.split("(?!^)")) { + out.add(new Tuple2(in.length(), letter)); + } + return out; } - return out; - } - }); + }); JavaTestUtils.attachTestOutputStream(flatMapped); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); Assert.assertEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testUnion() { List> inputData1 = Arrays.asList( @@ -778,6 +804,7 @@ public static > void assertOrderInvariantEquals( // PairDStream Functions + @SuppressWarnings("unchecked") @Test public void testPairFilter() { List> inputData = Arrays.asList( @@ -789,7 +816,7 @@ public void testPairFilter() { Arrays.asList(new Tuple2("yankees", 7))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream pairStream = stream.map( + 
JavaPairDStream pairStream = stream.mapToPair( new PairFunction() { @Override public Tuple2 call(String in) throws Exception { @@ -810,7 +837,8 @@ public Boolean call(Tuple2 in) throws Exception { Assert.assertEquals(expected, result); } - List>> stringStringKVStream = Arrays.asList( + @SuppressWarnings("unchecked") + private List>> stringStringKVStream = Arrays.asList( Arrays.asList(new Tuple2("california", "dodgers"), new Tuple2("california", "giants"), new Tuple2("new york", "yankees"), @@ -820,7 +848,8 @@ public Boolean call(Tuple2 in) throws Exception { new Tuple2("new york", "rangers"), new Tuple2("new york", "islanders"))); - List>> stringIntKVStream = Arrays.asList( + @SuppressWarnings("unchecked") + private List>> stringIntKVStream = Arrays.asList( Arrays.asList( new Tuple2("california", 1), new Tuple2("california", 3), @@ -832,6 +861,7 @@ public Boolean call(Tuple2 in) throws Exception { new Tuple2("new york", 3), new Tuple2("new york", 1))); + @SuppressWarnings("unchecked") @Test public void testPairMap() { // Maps pair -> pair of different type List>> inputData = stringIntKVStream; @@ -850,7 +880,7 @@ public void testPairMap() { // Maps pair -> pair of different type JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream reversed = pairStream.map( + JavaPairDStream reversed = pairStream.mapToPair( new PairFunction, Integer, String>() { @Override public Tuple2 call(Tuple2 in) throws Exception { @@ -864,6 +894,7 @@ public Tuple2 call(Tuple2 in) throws Exception Assert.assertEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testPairMapPartitions() { // Maps pair -> pair of different type List>> inputData = stringIntKVStream; @@ -882,7 +913,7 @@ public void testPairMapPartitions() { // Maps pair -> pair of different type JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream reversed = pairStream.mapPartitions( + JavaPairDStream reversed = pairStream.mapPartitionsToPair( new PairFlatMapFunction>, Integer, String>() { @Override public Iterable> call(Iterator> in) throws Exception { @@ -901,6 +932,7 @@ public Iterable> call(Iterator> Assert.assertEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testPairMap2() { // Maps pair -> single List>> inputData = stringIntKVStream; @@ -925,6 +957,7 @@ public Integer call(Tuple2 in) throws Exception { Assert.assertEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair List>> inputData = Arrays.asList( @@ -950,7 +983,7 @@ public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream flatMapped = pairStream.flatMap( + JavaPairDStream flatMapped = pairStream.flatMapToPair( new PairFlatMapFunction, Integer, String>() { @Override public Iterable> call(Tuple2 in) throws Exception { @@ -967,6 +1000,7 @@ public Iterable> call(Tuple2 in) throws Assert.assertEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testPairGroupByKey() { List>> inputData = stringStringKVStream; @@ -989,6 +1023,7 @@ public void testPairGroupByKey() { Assert.assertEquals(expected, result); } + 
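The pair-producing operations exercised above keep their behaviour but gain explicit names in the Java API: map, flatMap, mapPartitions, transform and transformWith become mapToPair, flatMapToPair, mapPartitionsToPair, transformToPair and transformWithToPair whenever the result is a JavaPairDStream. A minimal sketch of the renamed entry point, assuming a JavaDStream of String words is in scope (WordPairSketch is illustrative only):

    import org.apache.spark.api.java.function.PairFunction;
    import org.apache.spark.streaming.api.java.JavaDStream;
    import org.apache.spark.streaming.api.java.JavaPairDStream;
    import scala.Tuple2;

    class WordPairSketch {
        // Before this change the same call was spelled words.map(...); only the name differs.
        static JavaPairDStream<String, Integer> pairUp(JavaDStream<String> words) {
            return words.mapToPair(new PairFunction<String, String, Integer>() {
                @Override
                public Tuple2<String, Integer> call(String w) {
                    return new Tuple2<String, Integer>(w, 1);
                }
            });
        }
    }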
@SuppressWarnings("unchecked") @Test public void testPairReduceByKey() { List>> inputData = stringIntKVStream; @@ -1013,6 +1048,7 @@ public void testPairReduceByKey() { Assert.assertEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testCombineByKey() { List>> inputData = stringIntKVStream; @@ -1043,6 +1079,7 @@ public Integer call(Integer i) throws Exception { Assert.assertEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testCountByValue() { List> inputData = Arrays.asList( @@ -1068,6 +1105,7 @@ public void testCountByValue() { Assert.assertEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testGroupByKeyAndWindow() { List>> inputData = stringIntKVStream; @@ -1113,6 +1151,7 @@ private Tuple2> convert(Tuple2> t return new Tuple2>(tuple._1(), new HashSet(tuple._2())); } + @SuppressWarnings("unchecked") @Test public void testReduceByKeyAndWindow() { List>> inputData = stringIntKVStream; @@ -1136,6 +1175,7 @@ public void testReduceByKeyAndWindow() { Assert.assertEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testUpdateStateByKey() { List>> inputData = stringIntKVStream; @@ -1171,6 +1211,7 @@ public Optional call(List values, Optional state) { Assert.assertEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testReduceByKeyAndWindowWithInverse() { List>> inputData = stringIntKVStream; @@ -1187,13 +1228,15 @@ public void testReduceByKeyAndWindowWithInverse() { JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream reduceWindowed = - pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), new Duration(2000), new Duration(1000)); + pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), + new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(reduceWindowed); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); Assert.assertEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testCountByValueAndWindow() { List> inputData = Arrays.asList( @@ -1227,6 +1270,7 @@ public void testCountByValueAndWindow() { Assert.assertEquals(expected, unorderedResult); } + @SuppressWarnings("unchecked") @Test public void testPairTransform() { List>> inputData = Arrays.asList( @@ -1257,7 +1301,7 @@ public void testPairTransform() { ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream sorted = pairStream.transform( + JavaPairDStream sorted = pairStream.transformToPair( new Function, JavaPairRDD>() { @Override public JavaPairRDD call(JavaPairRDD in) throws Exception { @@ -1271,6 +1315,7 @@ public JavaPairRDD call(JavaPairRDD in) thro Assert.assertEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testPairToNormalRDDTransform() { List>> inputData = Arrays.asList( @@ -1312,6 +1357,8 @@ public Integer call(Tuple2 in) { Assert.assertEquals(expected, result); } + @SuppressWarnings("unchecked") + @Test public void testMapValues() { List>> inputData = stringStringKVStream; @@ -1342,6 +1389,7 @@ public String call(String s) throws Exception { Assert.assertEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testFlatMapValues() { List>> inputData = stringStringKVStream; @@ -1386,6 +1434,7 @@ public Iterable call(String in) { Assert.assertEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testCoGroup() { List>> stringStringKVStream1 = 
Arrays.asList( @@ -1429,6 +1478,7 @@ public void testCoGroup() { Assert.assertEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testJoin() { List>> stringStringKVStream1 = Arrays.asList( @@ -1472,6 +1522,7 @@ public void testJoin() { Assert.assertEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testLeftOuterJoin() { List>> stringStringKVStream1 = Arrays.asList( @@ -1503,6 +1554,7 @@ public void testLeftOuterJoin() { Assert.assertEquals(expected, result); } + @SuppressWarnings("unchecked") @Test public void testCheckpointMasterRecovery() throws InterruptedException { List> inputData = Arrays.asList( @@ -1541,7 +1593,8 @@ public Integer call(String s) throws Exception { } - /** TEST DISABLED: Pending a discussion about checkpoint() semantics with TD + /* TEST DISABLED: Pending a discussion about checkpoint() semantics with TD + @SuppressWarnings("unchecked") @Test public void testCheckpointofIndividualStream() throws InterruptedException { List> inputData = Arrays.asList( @@ -1580,17 +1633,16 @@ public void testSocketTextStream() { @Test public void testSocketString() { - class Converter extends Function> { - public Iterable call(InputStream in) { + + class Converter implements Function> { + public Iterable call(InputStream in) throws IOException { BufferedReader reader = new BufferedReader(new InputStreamReader(in)); List out = new ArrayList(); - try { - while (true) { - String line = reader.readLine(); - if (line == null) { break; } - out.add(line); - } - } catch (IOException e) { } + while (true) { + String line = reader.readLine(); + if (line == null) { break; } + out.add(line); + } return out; } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala index 15f13d5b19946..7d18a0fcf7ba8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming.util -import org.scalatest.FunSuite import java.io.ByteArrayOutputStream import java.util.concurrent.TimeUnit._ +import org.scalatest.FunSuite + class RateLimitedOutputStreamSuite extends FunSuite { private def benchmark[U](f: => U): Long = { @@ -32,9 +33,11 @@ class RateLimitedOutputStreamSuite extends FunSuite { test("write") { val underlying = new ByteArrayOutputStream val data = "X" * 41000 - val stream = new RateLimitedOutputStream(underlying, 10000) + val stream = new RateLimitedOutputStream(underlying, desiredBytesPerSec = 10000) val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) } - assert(SECONDS.convert(elapsedNs, NANOSECONDS) == 4) - assert(underlying.toString("UTF-8") == data) + + // We accept anywhere from 4.0 to 4.99999 seconds since the value is rounded down. 
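The comment above relies on the truncating behaviour of java.util.concurrent.TimeUnit: conversions round toward zero, so every elapsed time in [4 s, 5 s) converts to 4 seconds, which is why the assertion that follows can pin the value exactly. A tiny self-contained check, separate from the suite:

    import java.util.concurrent.TimeUnit.{NANOSECONDS, SECONDS}

    object TruncationCheck {
      def main(args: Array[String]): Unit = {
        assert(SECONDS.convert(4000000000L, NANOSECONDS) == 4L)   // exactly 4 s
        assert(SECONDS.convert(4999999999L, NANOSECONDS) == 4L)   // 4.999... s is still 4
        assert(SECONDS.convert(5000000000L, NANOSECONDS) == 5L)   // 5 s rolls over
      }
    }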
+ assert(SECONDS.convert(elapsedNs, NANOSECONDS) === 4) + assert(underlying.toString("UTF-8") === data) } } diff --git a/tools/pom.xml b/tools/pom.xml index a27f0db6e5628..11433e596f5b0 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.0.0-incubating-SNAPSHOT + 1.0.0-SNAPSHOT ../pom.xml @@ -28,7 +28,21 @@ spark-tools_2.10 jar Spark Project Tools - http://spark.incubator.apache.org/ + http://spark.apache.org/ + + + + + yarn-alpha + + + org.apache.avro + avro + + + + diff --git a/yarn/alpha/pom.xml b/yarn/alpha/pom.xml index b026128980cb8..d0aeaceb0d23c 100644 --- a/yarn/alpha/pom.xml +++ b/yarn/alpha/pom.xml @@ -20,10 +20,24 @@ org.apache.spark yarn-parent_2.10 - 1.0.0-incubating-SNAPSHOT + 1.0.0-SNAPSHOT ../pom.xml + + + + yarn-alpha + + + org.apache.avro + avro + + + + + org.apache.spark spark-yarn-alpha_2.10 jar diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 2e46d750c4a38..910484ed5432a 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -27,7 +27,6 @@ import scala.collection.JavaConversions._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.net.NetUtils -import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ @@ -36,7 +35,8 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{ConverterUtils, Records} -import org.apache.spark.{SparkConf, SparkContext, Logging} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.Utils class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, @@ -61,9 +61,14 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, YarnConfiguration.DEFAULT_RM_AM_MAX_RETRIES) private var isLastAMRetry: Boolean = true - // Default to numWorkers * 2, with minimum of 3 - private val maxNumWorkerFailures = sparkConf.getInt("spark.yarn.max.worker.failures", - math.max(args.numWorkers * 2, 3)) + // Default to numExecutors * 2, with minimum of 3 + private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", + sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3))) + + private var registered = false + + private val sparkUser = Option(System.getenv("SPARK_USER")).getOrElse( + SparkContext.SPARK_UNKNOWN_USER) def run() { // Setup the directories so things go to yarn approved directories rather @@ -74,6 +79,9 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, // other spark processes running on the same box System.setProperty("spark.ui.port", "0") + // when running the AM, the Spark master is always "yarn-cluster" + System.setProperty("spark.master", "yarn-cluster") + // Use priority 30 as its higher then HDFS. Its same priority as MapReduce is using. 
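The maxNumExecutorFailures default introduced above keeps the deprecated spark.yarn.max.worker.failures key working by nesting it as the fallback of the new executor-named key. The lookup pattern in isolation (a plain Map stands in for SparkConf; the key names are the ones from the diff):

    // Prefer the new "executor" key, fall back to the deprecated "worker" key,
    // then to a computed default of max(numExecutors * 2, 3).
    def maxExecutorFailures(conf: Map[String, String], numExecutors: Int): Int = {
      def getInt(key: String, default: Int): Int =
        conf.get(key).map(_.toInt).getOrElse(default)
      getInt("spark.yarn.max.executor.failures",
        getInt("spark.yarn.max.worker.failures", math.max(numExecutors * 2, 3)))
    }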
ShutdownHookManager.get().addShutdownHook(new AppMasterShutdownHook(this), 30) @@ -81,27 +89,16 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, isLastAMRetry = appAttemptId.getAttemptId() >= maxAppAttempts resourceManager = registerWithResourceManager() - // Workaround until hadoop moves to something which has - // https://issues.apache.org/jira/browse/HADOOP-8406 - fixed in (2.0.2-alpha but no 0.23 line) - // ignore result. - // This does not, unfortunately, always work reliably ... but alleviates the bug a lot of times - // Hence args.workerCores = numCore disabled above. Any better option? - - // Compute number of threads for akka - //val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory() - //if (minimumMemory > 0) { - // val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD - // val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0) - - // if (numCore > 0) { - // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406 - // TODO: Uncomment when hadoop is on a version which has this fixed. - // args.workerCores = numCore - // } - //} - // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf) + // setup AmIpFilter for the SparkUI - do this before we start the UI + addAmIpFilter() ApplicationMaster.register(this) + + // Call this to force generation of secret so it gets populated into the + // hadoop UGI. This has to happen before the startUserClass which does a + // doAs in order for the credentials to be passed on to the executor containers. + val securityMgr = new SecurityManager(sparkConf) + // Start the user's JAR userThread = startUserClass() @@ -110,10 +107,15 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, waitForSparkContextInitialized() // Do this after spark master is up and SparkContext is created so that we can register UI Url - val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster() + synchronized { + if (!isFinished) { + registerApplicationMaster() + registered = true + } + } // Allocate all containers - allocateWorkers() + allocateExecutors() // Wait for the user class to Finish userThread.join() @@ -121,6 +123,20 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, System.exit(0) } + // add the yarn amIpFilter that Yarn requires for properly securing the UI + private def addAmIpFilter() { + val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" + System.setProperty("spark.ui.filters", amFilter) + val proxy = YarnConfiguration.getProxyHostAndPort(conf) + val parts : Array[String] = proxy.split(":") + val uriBase = "http://" + proxy + + System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) + + val params = "PROXY_HOST=" + parts(0) + "," + "PROXY_URI_BASE=" + uriBase + System.setProperty("spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.params", + params) + } + /** Get the Yarn approved local directories. 
*/ private def getLocalDirs(): String = { // Hadoop 0.23 and 2.x have different Environment variable names for the @@ -173,7 +189,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, false /* initialize */ , Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]]) val t = new Thread { - override def run() { + override def run(): Unit = SparkHadoopUtil.get.runAsUser(sparkUser) { () => var successed = false try { // Copy @@ -199,7 +215,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, t } - // this need to happen before allocateWorkers + // this need to happen before allocateExecutors private def waitForSparkContextInitialized() { logInfo("Waiting for spark context initialization") try { @@ -208,7 +224,8 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, var count = 0 val waitTime = 10000L val numTries = sparkConf.getInt("spark.yarn.ApplicationMaster.waitTries", 10) - while (ApplicationMaster.sparkContextRef.get() == null && count < numTries) { + while (ApplicationMaster.sparkContextRef.get() == null && count < numTries + && !isFinished) { logInfo("Waiting for spark context initialization ... " + count) count = count + 1 ApplicationMaster.sparkContextRef.wait(waitTime) @@ -243,21 +260,21 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, } } - private def allocateWorkers() { + private def allocateExecutors() { try { - logInfo("Allocating " + args.numWorkers + " workers.") + logInfo("Allocating " + args.numExecutors + " executors.") // Wait until all containers have finished // TODO: This is a bit ugly. Can we make it nicer? // TODO: Handle container failure // Exists the loop if the user thread exits. - while (yarnAllocator.getNumWorkersRunning < args.numWorkers && userThread.isAlive) { - if (yarnAllocator.getNumWorkersFailed >= maxNumWorkerFailures) { + while (yarnAllocator.getNumExecutorsRunning < args.numExecutors && userThread.isAlive) { + if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { finishApplicationMaster(FinalApplicationStatus.FAILED, - "max number of worker failures reached") + "max number of executor failures reached") } yarnAllocator.allocateContainers( - math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0)) + math.max(args.numExecutors - yarnAllocator.getNumExecutorsRunning, 0)) ApplicationMaster.incrementAllocatorLoop(1) Thread.sleep(100) } @@ -266,7 +283,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, // so that the loop in ApplicationMaster#sparkContextInitialized() breaks. ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT) } - logInfo("All workers have launched.") + logInfo("All executors have launched.") // Launch a progress reporter thread, else the app will get killed after expiration // (def: 10mins) timeout. 
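The renamed allocateExecutors loop above reduces to: on every pass, request only the executors still missing, and abort once the failure count reaches maxNumExecutorFailures. A stripped-down sketch of that control flow, with a small trait standing in for YarnAllocationHandler:

    trait Allocator {
      def running: Int
      def failed: Int
      def allocate(count: Int): Unit
    }

    // Ask only for the shortfall on each pass; give up once too many executors
    // have failed. Sleeps briefly between passes, as the real loop does.
    def allocateUntilReady(alloc: Allocator, target: Int, maxFailures: Int): Unit = {
      while (alloc.running < target) {
        if (alloc.failed >= maxFailures) {
          sys.error("max number of executor failures reached")
        }
        alloc.allocate(math.max(target - alloc.running, 0))
        Thread.sleep(100)
      }
    }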
@@ -292,15 +309,15 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, val t = new Thread { override def run() { while (userThread.isAlive) { - if (yarnAllocator.getNumWorkersFailed >= maxNumWorkerFailures) { + if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { finishApplicationMaster(FinalApplicationStatus.FAILED, - "max number of worker failures reached") + "max number of executor failures reached") } - val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning - if (missingWorkerCount > 0) { + val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning + if (missingExecutorCount > 0) { logInfo("Allocating %d containers to make up for (potentially) lost containers". - format(missingWorkerCount)) - yarnAllocator.allocateContainers(missingWorkerCount) + format(missingExecutorCount)) + yarnAllocator.allocateContainers(missingExecutorCount) } else sendProgress() Thread.sleep(sleepTime) @@ -341,17 +358,19 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, return } isFinished = true + + logInfo("finishApplicationMaster with " + status) + if (registered) { + val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest]) + .asInstanceOf[FinishApplicationMasterRequest] + finishReq.setAppAttemptId(appAttemptId) + finishReq.setFinishApplicationStatus(status) + finishReq.setDiagnostics(diagnostics) + // Set tracking url to empty since we don't have a history server. + finishReq.setTrackingUrl("") + resourceManager.finishApplicationMaster(finishReq) + } } - - logInfo("finishApplicationMaster with " + status) - val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest]) - .asInstanceOf[FinishApplicationMasterRequest] - finishReq.setAppAttemptId(appAttemptId) - finishReq.setFinishApplicationStatus(status) - finishReq.setDiagnostics(diagnostics) - // Set tracking url to empty since we don't have a history server. 
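The reworked finishApplicationMaster above adds two guards: the method now does its work at most once, and it only sends the FinishApplicationMasterRequest if registration with the ResourceManager actually happened. The same shape in isolation:

    // Minimal sketch of the finish-once / unregister-only-if-registered guard.
    class FinishGuard {
      private var isFinished = false
      private var registered = false

      def markRegistered(): Unit = synchronized { registered = true }

      def finish(unregister: () => Unit): Unit = synchronized {
        if (!isFinished) {
          isFinished = true
          if (registered) unregister()
        }
      }
    }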
- finishReq.setTrackingUrl("") - resourceManager.finishApplicationMaster(finishReq) } /** diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala similarity index 89% rename from yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala rename to yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index 138c27910b0b0..7b0e020263835 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -29,12 +29,12 @@ import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import akka.actor._ import akka.remote._ import akka.actor.Terminated -import org.apache.spark.{SparkConf, SparkContext, Logging} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.util.{Utils, AkkaUtils} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.scheduler.SplitInfo -class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, sparkConf: SparkConf) +class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sparkConf: SparkConf) extends Logging { def this(args: ApplicationMasterArguments, sparkConf: SparkConf) = this(args, new Configuration(), sparkConf) @@ -50,8 +50,9 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar private var yarnAllocator: YarnAllocationHandler = _ private var driverClosed:Boolean = false + val securityManager = new SecurityManager(sparkConf) val actorSystem : ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, - conf = sparkConf)._1 + conf = sparkConf, securityManager = securityManager)._1 var actor: ActorRef = _ // This actor just working as a monitor to watch on Driver Actor. @@ -88,7 +89,7 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory() if (minimumMemory > 0) { - val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD + val mem = args.executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0) if (numCore > 0) { @@ -101,7 +102,7 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar waitForSparkMaster() // Allocate all containers - allocateWorkers() + allocateExecutors() // Launch a progress reporter thread, else app will get killed after expiration (def: 10mins) timeout // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse. @@ -110,6 +111,7 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar // we want to be reasonably responsive without causing too many requests to RM. val schedulerInterval = System.getProperty("spark.yarn.scheduler.heartbeat.interval-ms", "5000").toLong + // must be <= timeoutInterval / 2. val interval = math.min(timeoutInterval / 2, schedulerInterval) @@ -197,7 +199,7 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar } - private def allocateWorkers() { + private def allocateExecutors() { // Fixme: should get preferredNodeLocationData from SparkContext, just fake a empty one for now. 
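The new comment on the heartbeat interval in ExecutorLauncher makes the invariant explicit: whatever spark.yarn.scheduler.heartbeat.interval-ms is set to, the effective interval is capped at half the RM expiry window so progress is always reported in time. Pulled out on its own (the 600000 is an illustrative expiry value, not one taken from this patch):

    // must be <= timeoutInterval / 2
    val timeoutInterval = 600000L   // stand-in for YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS
    val schedulerInterval =
      System.getProperty("spark.yarn.scheduler.heartbeat.interval-ms", "5000").toLong
    val interval = math.min(timeoutInterval / 2, schedulerInterval)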
val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = @@ -206,16 +208,16 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, appAttemptId, args, preferredNodeLocationData, sparkConf) - logInfo("Allocating " + args.numWorkers + " workers.") + logInfo("Allocating " + args.numExecutors + " executors.") // Wait until all containers have finished // TODO: This is a bit ugly. Can we make it nicer? // TODO: Handle container failure - while ((yarnAllocator.getNumWorkersRunning < args.numWorkers) && (!driverClosed)) { - yarnAllocator.allocateContainers(math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0)) + while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed)) { + yarnAllocator.allocateContainers(math.max(args.numExecutors - yarnAllocator.getNumExecutorsRunning, 0)) Thread.sleep(100) } - logInfo("All workers have launched.") + logInfo("All executors have launched.") } @@ -226,10 +228,10 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar val t = new Thread { override def run() { while (!driverClosed) { - val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning - if (missingWorkerCount > 0) { - logInfo("Allocating " + missingWorkerCount + " containers to make up for (potentially ?) lost containers") - yarnAllocator.allocateContainers(missingWorkerCount) + val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning + if (missingExecutorCount > 0) { + logInfo("Allocating " + missingExecutorCount + " containers to make up for (potentially ?) lost containers") + yarnAllocator.allocateContainers(missingExecutorCount) } else sendProgress() Thread.sleep(sleepTime) @@ -262,9 +264,9 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar } -object WorkerLauncher { +object ExecutorLauncher { def main(argStrings: Array[String]) { val args = new ApplicationMasterArguments(argStrings) - new WorkerLauncher(args).run() + new ExecutorLauncher(args).run() } } diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala similarity index 93% rename from yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala rename to yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 8c686e393f4f8..981e8b05f602d 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -38,16 +38,16 @@ import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records, ProtoUtils} import org.apache.spark.{SparkConf, Logging} -class WorkerRunnable( +class ExecutorRunnable( container: Container, conf: Configuration, spConf: SparkConf, masterAddress: String, slaveId: String, hostname: String, - workerMemory: Int, - workerCores: Int) - extends Runnable with WorkerRunnableUtil with Logging { + executorMemory: Int, + executorCores: Int) + extends Runnable with ExecutorRunnableUtil with Logging { var rpc: YarnRPC = YarnRPC.create(conf) var cm: ContainerManager = _ @@ -55,7 +55,7 @@ class WorkerRunnable( val yarnConf: YarnConfiguration = new YarnConfiguration(conf) def run = { - logInfo("Starting Worker Container") + logInfo("Starting Executor Container") cm = connectToCM startContainer } 
@@ -81,8 +81,8 @@ class WorkerRunnable( credentials.writeTokenStorageToStream(dob) ctx.setContainerTokens(ByteBuffer.wrap(dob.getData())) - val commands = prepareCommand(masterAddress, slaveId, hostname, workerMemory, workerCores) - logInfo("Setting up worker with commands: " + commands) + val commands = prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores) + logInfo("Setting up executor with commands: " + commands) ctx.setCommands(commands) // Send the start request to the ContainerManager diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index e91257be8ed00..2056667af50cb 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -58,9 +58,9 @@ private[yarn] class YarnAllocationHandler( val conf: Configuration, val resourceManager: AMRMProtocol, val appAttemptId: ApplicationAttemptId, - val maxWorkers: Int, - val workerMemory: Int, - val workerCores: Int, + val maxExecutors: Int, + val executorMemory: Int, + val executorCores: Int, val preferredHostToCount: Map[String, Int], val preferredRackToCount: Map[String, Int], val sparkConf: SparkConf) @@ -84,39 +84,39 @@ private[yarn] class YarnAllocationHandler( // Containers to be released in next request to RM private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean] - private val numWorkersRunning = new AtomicInteger() - // Used to generate a unique id per worker - private val workerIdCounter = new AtomicInteger() + private val numExecutorsRunning = new AtomicInteger() + // Used to generate a unique id per executor + private val executorIdCounter = new AtomicInteger() private val lastResponseId = new AtomicInteger() - private val numWorkersFailed = new AtomicInteger() + private val numExecutorsFailed = new AtomicInteger() - def getNumWorkersRunning: Int = numWorkersRunning.intValue + def getNumExecutorsRunning: Int = numExecutorsRunning.intValue - def getNumWorkersFailed: Int = numWorkersFailed.intValue + def getNumExecutorsFailed: Int = numExecutorsFailed.intValue def isResourceConstraintSatisfied(container: Container): Boolean = { - container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + container.getResource.getMemory >= (executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD) } - def allocateContainers(workersToRequest: Int) { + def allocateContainers(executorsToRequest: Int) { // We need to send the request only once from what I understand ... but for now, not modifying // this much. 
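The renamed bookkeeping in YarnAllocationHandler above is easier to see without the surrounding allocation code: the executor counts are plain AtomicIntegers, and a container is accepted only if it is large enough for the executor memory plus the YARN overhead. A condensed sketch (memory figures are parameters here, not values from this patch):

    import java.util.concurrent.atomic.AtomicInteger

    class ExecutorBookkeeping(executorMemory: Int, memoryOverhead: Int) {
      private val numExecutorsRunning = new AtomicInteger()
      private val numExecutorsFailed  = new AtomicInteger()

      def executorStarted(): Int = numExecutorsRunning.incrementAndGet()
      def executorFailed(): Int  = numExecutorsFailed.incrementAndGet()

      def getNumExecutorsRunning: Int = numExecutorsRunning.intValue
      def getNumExecutorsFailed: Int  = numExecutorsFailed.intValue

      // A container is usable only if it can hold the executor heap plus the
      // YARN memory overhead.
      def isResourceConstraintSatisfied(containerMemory: Int): Boolean =
        containerMemory >= executorMemory + memoryOverhead
    }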
// Keep polling the Resource Manager for containers - val amResp = allocateWorkerResources(workersToRequest).getAMResponse + val amResp = allocateExecutorResources(executorsToRequest).getAMResponse val _allocatedContainers = amResp.getAllocatedContainers() if (_allocatedContainers.size > 0) { logDebug(""" Allocated containers: %d - Current worker count: %d + Current executor count: %d Containers released: %s Containers to be released: %s Cluster resources: %s """.format( _allocatedContainers.size, - numWorkersRunning.get(), + numExecutorsRunning.get(), releasedContainerList, pendingReleaseContainers, amResp.getAvailableResources)) @@ -221,59 +221,59 @@ private[yarn] class YarnAllocationHandler( // Run each of the allocated containers for (container <- allocatedContainers) { - val numWorkersRunningNow = numWorkersRunning.incrementAndGet() - val workerHostname = container.getNodeId.getHost + val numExecutorsRunningNow = numExecutorsRunning.incrementAndGet() + val executorHostname = container.getNodeId.getHost val containerId = container.getId assert( - container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)) + container.getResource.getMemory >= (executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD)) - if (numWorkersRunningNow > maxWorkers) { + if (numExecutorsRunningNow > maxExecutors) { logInfo("""Ignoring container %s at host %s, since we already have the required number of - containers for it.""".format(containerId, workerHostname)) + containers for it.""".format(containerId, executorHostname)) releasedContainerList.add(containerId) // reset counter back to old value. - numWorkersRunning.decrementAndGet() + numExecutorsRunning.decrementAndGet() } else { // Deallocate + allocate can result in reusing id's wrongly - so use a different counter - // (workerIdCounter) - val workerId = workerIdCounter.incrementAndGet().toString + // (executorIdCounter) + val executorId = executorIdCounter.incrementAndGet().toString val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( sparkConf.get("spark.driver.host"), sparkConf.get("spark.driver.port"), CoarseGrainedSchedulerBackend.ACTOR_NAME) - logInfo("launching container on " + containerId + " host " + workerHostname) + logInfo("launching container on " + containerId + " host " + executorHostname) // Just to be safe, simply remove it from pendingReleaseContainers. // Should not be there, but .. pendingReleaseContainers.remove(containerId) - val rack = YarnAllocationHandler.lookupRack(conf, workerHostname) + val rack = YarnAllocationHandler.lookupRack(conf, executorHostname) allocatedHostToContainersMap.synchronized { - val containerSet = allocatedHostToContainersMap.getOrElseUpdate(workerHostname, + val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname, new HashSet[ContainerId]()) containerSet += containerId - allocatedContainerToHostMap.put(containerId, workerHostname) + allocatedContainerToHostMap.put(containerId, executorHostname) if (rack != null) { allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1) } } new Thread( - new WorkerRunnable(container, conf, sparkConf, driverUrl, workerId, - workerHostname, workerMemory, workerCores) + new ExecutorRunnable(container, conf, sparkConf, driverUrl, executorId, + executorHostname, executorMemory, executorCores) ).start() } } logDebug(""" Finished processing %d containers. 
- Current number of workers running: %d, + Current number of executors running: %d, releasedContainerList: %s, pendingReleaseContainers: %s """.format( allocatedContainers.size, - numWorkersRunning.get(), + numExecutorsRunning.get(), releasedContainerList, pendingReleaseContainers)) } @@ -292,7 +292,7 @@ private[yarn] class YarnAllocationHandler( } else { // Simply decrement count - next iteration of ReporterThread will take care of allocating. - numWorkersRunning.decrementAndGet() + numExecutorsRunning.decrementAndGet() logInfo("Completed container %s (state: %s, exit status: %s)".format( containerId, completedContainer.getState, @@ -302,7 +302,7 @@ private[yarn] class YarnAllocationHandler( // now I think its ok as none of the containers are expected to exit if (completedContainer.getExitStatus() != 0) { logInfo("Container marked as failed: " + containerId) - numWorkersFailed.incrementAndGet() + numExecutorsFailed.incrementAndGet() } } @@ -332,12 +332,12 @@ private[yarn] class YarnAllocationHandler( } logDebug(""" Finished processing %d completed containers. - Current number of workers running: %d, + Current number of executors running: %d, releasedContainerList: %s, pendingReleaseContainers: %s """.format( completedContainers.size, - numWorkersRunning.get(), + numExecutorsRunning.get(), releasedContainerList, pendingReleaseContainers)) } @@ -387,18 +387,18 @@ private[yarn] class YarnAllocationHandler( retval } - private def allocateWorkerResources(numWorkers: Int): AllocateResponse = { + private def allocateExecutorResources(numExecutors: Int): AllocateResponse = { var resourceRequests: List[ResourceRequest] = null // default. - if (numWorkers <= 0 || preferredHostToCount.isEmpty) { - logDebug("numWorkers: " + numWorkers + ", host preferences: " + preferredHostToCount.isEmpty) + if (numExecutors <= 0 || preferredHostToCount.isEmpty) { + logDebug("numExecutors: " + numExecutors + ", host preferences: " + preferredHostToCount.isEmpty) resourceRequests = List( - createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY)) + createResourceRequest(AllocationType.ANY, null, numExecutors, YarnAllocationHandler.PRIORITY)) } else { - // request for all hosts in preferred nodes and for numWorkers - + // request for all hosts in preferred nodes and for numExecutors - // candidates.size, request by default allocation policy. val hostContainerRequests: ArrayBuffer[ResourceRequest] = new ArrayBuffer[ResourceRequest](preferredHostToCount.size) @@ -419,7 +419,7 @@ private[yarn] class YarnAllocationHandler( val anyContainerRequests: ResourceRequest = createResourceRequest( AllocationType.ANY, resource = null, - numWorkers, + numExecutors, YarnAllocationHandler.PRIORITY) val containerRequests: ArrayBuffer[ResourceRequest] = new ArrayBuffer[ResourceRequest]( @@ -441,9 +441,9 @@ private[yarn] class YarnAllocationHandler( val releasedContainerList = createReleasedContainerList() req.addAllReleases(releasedContainerList) - if (numWorkers > 0) { - logInfo("Allocating %d worker containers with %d of memory each.".format(numWorkers, - workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)) + if (numExecutors > 0) { + logInfo("Allocating %d executor containers with %d of memory each.".format(numExecutors, + executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD)) } else { logDebug("Empty allocation req .. 
release : " + releasedContainerList) @@ -464,7 +464,7 @@ private[yarn] class YarnAllocationHandler( private def createResourceRequest( requestType: AllocationType.AllocationType, resource:String, - numWorkers: Int, + numExecutors: Int, priority: Int): ResourceRequest = { // If hostname specified, we need atleast two requests - node local and rack local. @@ -473,7 +473,7 @@ private[yarn] class YarnAllocationHandler( case AllocationType.HOST => { assert(YarnAllocationHandler.ANY_HOST != resource) val hostname = resource - val nodeLocal = createResourceRequestImpl(hostname, numWorkers, priority) + val nodeLocal = createResourceRequestImpl(hostname, numExecutors, priority) // Add to host->rack mapping YarnAllocationHandler.populateRackInfo(conf, hostname) @@ -482,10 +482,10 @@ private[yarn] class YarnAllocationHandler( } case AllocationType.RACK => { val rack = resource - createResourceRequestImpl(rack, numWorkers, priority) + createResourceRequestImpl(rack, numExecutors, priority) } case AllocationType.ANY => createResourceRequestImpl( - YarnAllocationHandler.ANY_HOST, numWorkers, priority) + YarnAllocationHandler.ANY_HOST, numExecutors, priority) case _ => throw new IllegalArgumentException( "Unexpected/unsupported request type: " + requestType) } @@ -493,13 +493,13 @@ private[yarn] class YarnAllocationHandler( private def createResourceRequestImpl( hostname:String, - numWorkers: Int, + numExecutors: Int, priority: Int): ResourceRequest = { val rsrcRequest = Records.newRecord(classOf[ResourceRequest]) val memCapability = Records.newRecord(classOf[Resource]) // There probably is some overhead here, let's reserve a bit more memory. - memCapability.setMemory(workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + memCapability.setMemory(executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD) rsrcRequest.setCapability(memCapability) val pri = Records.newRecord(classOf[Priority]) @@ -508,7 +508,7 @@ private[yarn] class YarnAllocationHandler( rsrcRequest.setHostName(hostname) - rsrcRequest.setNumContainers(java.lang.Math.max(numWorkers, 0)) + rsrcRequest.setNumContainers(java.lang.Math.max(numExecutors, 0)) rsrcRequest } @@ -560,9 +560,9 @@ object YarnAllocationHandler { conf, resourceManager, appAttemptId, - args.numWorkers, - args.workerMemory, - args.workerCores, + args.numExecutors, + args.executorMemory, + args.executorCores, Map[String, Int](), Map[String, Int](), sparkConf) @@ -582,9 +582,9 @@ object YarnAllocationHandler { conf, resourceManager, appAttemptId, - args.numWorkers, - args.workerMemory, - args.workerCores, + args.numExecutors, + args.executorMemory, + args.executorCores, hostToCount, rackToCount, sparkConf) @@ -594,9 +594,9 @@ object YarnAllocationHandler { conf: Configuration, resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId, - maxWorkers: Int, - workerMemory: Int, - workerCores: Int, + maxExecutors: Int, + executorMemory: Int, + executorCores: Int, map: collection.Map[String, collection.Set[SplitInfo]], sparkConf: SparkConf): YarnAllocationHandler = { @@ -606,9 +606,9 @@ object YarnAllocationHandler { conf, resourceManager, appAttemptId, - maxWorkers, - workerMemory, - workerCores, + maxExecutors, + executorMemory, + executorCores, hostToCount, rackToCount, sparkConf) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index f76a5ddd39e90..25cc9016b10a6 100644 --- 
a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -24,9 +24,9 @@ class ApplicationMasterArguments(val args: Array[String]) { var userJar: String = null var userClass: String = null var userArgs: Seq[String] = Seq[String]() - var workerMemory = 1024 - var workerCores = 1 - var numWorkers = 2 + var executorMemory = 1024 + var executorCores = 1 + var numExecutors = 2 parseArgs(args.toList) @@ -36,7 +36,8 @@ class ApplicationMasterArguments(val args: Array[String]) { var args = inputArgs while (! args.isEmpty) { - + // --num-workers, --worker-memory, and --worker-cores are deprecated since 1.0, + // the properties with executor in their names are preferred. args match { case ("--jar") :: value :: tail => userJar = value @@ -50,16 +51,16 @@ class ApplicationMasterArguments(val args: Array[String]) { userArgsBuffer += value args = tail - case ("--num-workers") :: IntParam(value) :: tail => - numWorkers = value + case ("--num-workers" | "--num-executors") :: IntParam(value) :: tail => + numExecutors = value args = tail - case ("--worker-memory") :: IntParam(value) :: tail => - workerMemory = value + case ("--worker-memory" | "--executor-memory") :: IntParam(value) :: tail => + executorMemory = value args = tail - case ("--worker-cores") :: IntParam(value) :: tail => - workerCores = value + case ("--worker-cores" | "--executor-cores") :: IntParam(value) :: tail => + executorCores = value args = tail case Nil => @@ -86,9 +87,9 @@ class ApplicationMasterArguments(val args: Array[String]) { " --class CLASS_NAME Name of your application's main class (required)\n" + " --args ARGS Arguments to be passed to your application's main class.\n" + " Mutliple invocations are possible, each will be passed in order.\n" + - " --num-workers NUM Number of workers to start (Default: 2)\n" + - " --worker-cores NUM Number of cores for the workers (Default: 1)\n" + - " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n") + " --num-executors NUM Number of executors to start (Default: 2)\n" + + " --executor-cores NUM Number of cores for the executors (Default: 1)\n" + + " --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G)\n") System.exit(exitCode) } } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 1419f215c78e5..c565f2dde24fc 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -33,9 +33,9 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { var userJar: String = null var userClass: String = null var userArgs: Seq[String] = Seq[String]() - var workerMemory = 1024 // MB - var workerCores = 1 - var numWorkers = 2 + var executorMemory = 1024 // MB + var executorCores = 1 + var numExecutors = 2 var amQueue = sparkConf.get("QUEUE", "default") var amMemory: Int = 512 // MB var amClass: String = "org.apache.spark.deploy.yarn.ApplicationMaster" @@ -67,24 +67,39 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { userArgsBuffer += value args = tail - case ("--master-class") :: value :: tail => + case ("--master-class" | "--am-class") :: value :: tail => + if (args(0) == "--master-class") { + println("--master-class is deprecated. 
Use --am-class instead.") + } amClass = value args = tail - case ("--master-memory") :: MemoryParam(value) :: tail => + case ("--master-memory" | "--driver-memory") :: MemoryParam(value) :: tail => + if (args(0) == "--master-memory") { + println("--master-memory is deprecated. Use --driver-memory instead.") + } amMemory = value args = tail - case ("--num-workers") :: IntParam(value) :: tail => - numWorkers = value + case ("--num-workers" | "--num-executors") :: IntParam(value) :: tail => + if (args(0) == "--num-workers") { + println("--num-workers is deprecated. Use --num-executors instead.") + } + numExecutors = value args = tail - case ("--worker-memory") :: MemoryParam(value) :: tail => - workerMemory = value + case ("--worker-memory" | "--executor-memory") :: MemoryParam(value) :: tail => + if (args(0) == "--worker-memory") { + println("--worker-memory is deprecated. Use --executor-memory instead.") + } + executorMemory = value args = tail - case ("--worker-cores") :: IntParam(value) :: tail => - workerCores = value + case ("--worker-cores" | "--executor-cores") :: IntParam(value) :: tail => + if (args(0) == "--worker-cores") { + println("--worker-cores is deprecated. Use --executor-cores instead.") + } + executorCores = value args = tail case ("--queue") :: value :: tail => @@ -108,7 +123,7 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { args = tail case Nil => - if (userJar == null || userClass == null) { + if (userClass == null) { printUsageAndExit(1) } @@ -129,15 +144,14 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { System.err.println( "Usage: org.apache.spark.deploy.yarn.Client [options] \n" + "Options:\n" + - " --jar JAR_PATH Path to your application's JAR file (required)\n" + + " --jar JAR_PATH Path to your application's JAR file (required in yarn-cluster mode)\n" + " --class CLASS_NAME Name of your application's main class (required)\n" + " --args ARGS Arguments to be passed to your application's main class.\n" + " Mutliple invocations are possible, each will be passed in order.\n" + - " --num-workers NUM Number of workers to start (Default: 2)\n" + - " --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" + - " --master-class CLASS_NAME Class Name for Master (Default: spark.deploy.yarn.ApplicationMaster)\n" + - " --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" + - " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" + + " --num-executors NUM Number of executors to start (Default: 2)\n" + + " --executor-cores NUM Number of cores for the executors (Default: 1).\n" + + " --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: 512 Mb)\n" + + " --executor-memory MEM Memory per executor (e.g. 
1000M, 2G) (Default: 1G)\n" + " --name NAME The name of your application (Default: Spark)\n" + " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')\n" + " --addJars jars Comma separated list of local jars that want SparkContext.addJar to work with.\n" + diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 2db5744be1a70..57e5761cba896 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -29,8 +29,10 @@ import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission; import org.apache.hadoop.io.DataOutputBuffer import org.apache.hadoop.mapred.Master +import org.apache.hadoop.mapreduce.MRJobConfig import org.apache.hadoop.net.NetUtils import org.apache.hadoop.security.UserGroupInformation +import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.protocolrecords._ @@ -68,12 +70,13 @@ trait ClientBase extends Logging { def validateArgs() = { Map( (System.getenv("SPARK_JAR") == null) -> "Error: You must set SPARK_JAR environment variable!", - (args.userJar == null) -> "Error: You must specify a user jar!", + ((args.userJar == null && args.amClass == classOf[ApplicationMaster].getName) -> + "Error: You must specify a user jar when running in standalone mode!"), (args.userClass == null) -> "Error: You must specify a user class!", - (args.numWorkers <= 0) -> "Error: You must specify at least 1 worker!", + (args.numExecutors <= 0) -> "Error: You must specify at least 1 executor!", (args.amMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> ("Error: AM memory size must be" + "greater than: " + YarnAllocationHandler.MEMORY_OVERHEAD), - (args.workerMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> ("Error: Worker memory size" + + (args.executorMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> ("Error: Executor memory size" + "must be greater than: " + YarnAllocationHandler.MEMORY_OVERHEAD.toString) ).foreach { case(cond, errStr) => if (cond) { @@ -92,9 +95,9 @@ trait ClientBase extends Logging { logInfo("Max mem capabililty of a single resource in this cluster " + maxMem) // If we have requested more then the clusters max for a single resource then exit. - if (args.workerMemory > maxMem) { - logError("Required worker memory (%d MB), is above the max threshold (%d MB) of this cluster.". - format(args.workerMemory, maxMem)) + if (args.executorMemory > maxMem) { + logError("Required executor memory (%d MB), is above the max threshold (%d MB) of this cluster.". + format(args.executorMemory, maxMem)) System.exit(1) } val amMem = args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD @@ -271,8 +274,9 @@ trait ClientBase extends Logging { ClientBase.populateClasspath(yarnConf, sparkConf, log4jConfLocalRes != null, env) env("SPARK_YARN_MODE") = "true" env("SPARK_YARN_STAGING_DIR") = stagingDir + env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() - // Set the environment variables to be passed on to the Workers. + // Set the environment variables to be passed on to the executors. 
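validateArgs in ClientBase, updated above for the executor terminology, uses a compact idiom: a Map from boolean condition to error message, walked once, exiting on the first condition that holds. The idiom on its own, with placeholder values (the 384 MB overhead is a stand-in for YarnAllocationHandler.MEMORY_OVERHEAD):

    val numExecutors = 2
    val executorMemory = 1024
    val memoryOverhead = 384   // stand-in, not a value from this patch

    Map(
      (numExecutors <= 0) -> "Error: You must specify at least 1 executor!",
      (executorMemory <= memoryOverhead) ->
        ("Error: Executor memory size must be greater than: " + memoryOverhead)
    ).foreach { case (cond, errStr) =>
      if (cond) {
        println(errStr)
        sys.exit(1)
      }
    }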
distCacheMgr.setDistFilesEnv(env) distCacheMgr.setDistArchivesEnv(env) @@ -356,9 +360,9 @@ trait ClientBase extends Logging { " --class " + args.userClass + " --jar " + args.userJar + userArgsToString(args) + - " --worker-memory " + args.workerMemory + - " --worker-cores " + args.workerCores + - " --num-workers " + args.numWorkers + + " --executor-memory " + args.executorMemory + + " --executor-cores " + args.executorCores + + " --num-executors " + args.numExecutors + " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" + " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") @@ -377,9 +381,50 @@ object ClientBase { // Based on code from org.apache.hadoop.mapreduce.v2.util.MRApps def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) { - for (c <- conf.getStrings(YarnConfiguration.YARN_APPLICATION_CLASSPATH)) { + val classpathEntries = Option(conf.getStrings( + YarnConfiguration.YARN_APPLICATION_CLASSPATH)).getOrElse( + getDefaultYarnApplicationClasspath()) + for (c <- classpathEntries) { Apps.addToEnvironment(env, Environment.CLASSPATH.name, c.trim) } + + val mrClasspathEntries = Option(conf.getStrings( + "mapreduce.application.classpath")).getOrElse( + getDefaultMRApplicationClasspath()) + if (mrClasspathEntries != null) { + for (c <- mrClasspathEntries) { + Apps.addToEnvironment(env, Environment.CLASSPATH.name, c.trim) + } + } + } + + def getDefaultYarnApplicationClasspath(): Array[String] = { + try { + val field = classOf[MRJobConfig].getField("DEFAULT_YARN_APPLICATION_CLASSPATH") + field.get(null).asInstanceOf[Array[String]] + } catch { + case err: NoSuchFieldError => null + case err: NoSuchFieldException => null + } + } + + /** + * In Hadoop 0.23, the MR application classpath comes with the YARN application + * classpath. In Hadoop 2.0, it's an array of Strings, and in 2.2+ it's a String. + * So we need to use reflection to retrieve it. + */ + def getDefaultMRApplicationClasspath(): Array[String] = { + try { + val field = classOf[MRJobConfig].getField("DEFAULT_MAPREDUCE_APPLICATION_CLASSPATH") + if (field.getType == classOf[String]) { + StringUtils.getStrings(field.get(null).asInstanceOf[String]) + } else { + field.get(null).asInstanceOf[Array[String]] + } + } catch { + case err: NoSuchFieldError => null + case err: NoSuchFieldException => null + } } def populateClasspath(conf: Configuration, sparkConf: SparkConf, addLog4j: Boolean, env: HashMap[String, String]) { diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala index 535abbfb7f638..68cda0f1c9f8b 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala @@ -46,10 +46,10 @@ class ClientDistributedCacheManager() extends Logging { /** * Add a resource to the list of distributed cache resources. This list can - * be sent to the ApplicationMaster and possibly the workers so that it can + * be sent to the ApplicationMaster and possibly the executors so that it can * be downloaded into the Hadoop distributed cache for use by this application. * Adds the LocalResource to the localResources HashMap passed in and saves - * the stats of the resources to they can be sent to the workers and verified. + * the stats of the resources to they can be sent to the executors and verified. 
* * @param fs FileSystem * @param conf Configuration diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala similarity index 95% rename from yarn/common/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnableUtil.scala rename to yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala index bfa8f84bf7f85..da0a6f74efcd5 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnableUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala @@ -39,7 +39,7 @@ import org.apache.spark.{SparkConf, Logging} import org.apache.hadoop.yarn.conf.YarnConfiguration -trait WorkerRunnableUtil extends Logging { +trait ExecutorRunnableUtil extends Logging { val yarnConf: YarnConfiguration val sparkConf: SparkConf @@ -49,13 +49,13 @@ trait WorkerRunnableUtil extends Logging { masterAddress: String, slaveId: String, hostname: String, - workerMemory: Int, - workerCores: Int) = { + executorMemory: Int, + executorCores: Int) = { // Extra options for the JVM var JAVA_OPTS = "" // Set the JVM memory - val workerMemoryString = workerMemory + "m" - JAVA_OPTS += "-Xms" + workerMemoryString + " -Xmx" + workerMemoryString + " " + val executorMemoryString = executorMemory + "m" + JAVA_OPTS += "-Xms" + executorMemoryString + " -Xmx" + executorMemoryString + " " if (env.isDefinedAt("SPARK_JAVA_OPTS")) { JAVA_OPTS += env("SPARK_JAVA_OPTS") + " " } @@ -97,7 +97,7 @@ trait WorkerRunnableUtil extends Logging { val commands = List[String](javaCommand + " -server " + // Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling. - // Not killing the task leaves various aspects of the worker and (to some extent) the jvm in + // Not killing the task leaves various aspects of the executor and (to some extent) the jvm in // an inconsistent state. // TODO: If the OOM is not recoverable by rescheduling it on different node, then do // 'something' to fail job ... akin to blacklisting trackers in mapred ? @@ -107,7 +107,7 @@ trait WorkerRunnableUtil extends Logging { masterAddress + " " + slaveId + " " + hostname + " " + - workerCores + + executorCores + " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" + " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 2ba2366ead171..4c6e1dcd6dac3 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -17,17 +17,23 @@ package org.apache.spark.deploy.yarn -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.conf.Configuration +import org.apache.spark.deploy.SparkHadoopUtil /** * Contains util methods to interact with Hadoop from spark. 
*/ class YarnSparkHadoopUtil extends SparkHadoopUtil { + override def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) { + dest.addCredentials(source.getCredentials()) + } + // Note that all params which start with SPARK are propagated all the way through, so if in yarn mode, this MUST be set to true. override def isYarnMode(): Boolean = { true } @@ -40,4 +46,24 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { val jobCreds = conf.getCredentials() jobCreds.mergeAll(UserGroupInformation.getCurrentUser().getCredentials()) } + + override def getCurrentUserCredentials(): Credentials = { + UserGroupInformation.getCurrentUser().getCredentials() + } + + override def addCurrentUserCredentials(creds: Credentials) { + UserGroupInformation.getCurrentUser().addCredentials(creds) + } + + override def addSecretKeyToUserCredentials(key: String, secret: String) { + val creds = new Credentials() + creds.addSecretKey(new Text(key), secret.getBytes()) + addCurrentUserCredentials(creds) + } + + override def getSecretKeyFromUserCredentials(key: String): Array[Byte] = { + val credentials = getCurrentUserCredentials() + if (credentials != null) credentials.getSecretKey(new Text(key)) else null + } + } diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala index 522e0a9ad7eeb..6b91e6b9eb899 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala @@ -25,7 +25,7 @@ import org.apache.spark.util.Utils /** * - * This scheduler launch worker through Yarn - by call into Client to launch WorkerLauncher as AM. + * This scheduler launches executors through Yarn - by calling into Client to launch ExecutorLauncher as AM. 
*/ private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends TaskSchedulerImpl(sc) { @@ -40,7 +40,7 @@ private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configur override def postStartHook() { - // The yarn application is running, but the worker might not yet ready + // The yarn application is running, but the executor might not yet ready // Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt Thread.sleep(2000L) logInfo("YarnClientClusterScheduler.postStartHook done") diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 22e55e0c60647..d1f13e3c369ed 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -44,10 +44,6 @@ private[spark] class YarnClientSchedulerBackend( override def start() { super.start() - val userJar = System.getenv("SPARK_YARN_APP_JAR") - if (userJar == null) - throw new SparkException("env SPARK_YARN_APP_JAR is not set") - val driverHost = conf.get("spark.driver.host") val driverPort = conf.get("spark.driver.port") val hostport = driverHost + ":" + driverPort @@ -55,22 +51,26 @@ private[spark] class YarnClientSchedulerBackend( val argsArrayBuf = new ArrayBuffer[String]() argsArrayBuf += ( "--class", "notused", - "--jar", userJar, + "--jar", null, "--args", hostport, - "--master-class", "org.apache.spark.deploy.yarn.WorkerLauncher" + "--am-class", "org.apache.spark.deploy.yarn.ExecutorLauncher" ) // process any optional arguments, use the defaults already defined in ClientArguments // if things aren't specified - Map("--master-memory" -> "SPARK_MASTER_MEMORY", - "--num-workers" -> "SPARK_WORKER_INSTANCES", - "--worker-memory" -> "SPARK_WORKER_MEMORY", - "--worker-cores" -> "SPARK_WORKER_CORES", - "--queue" -> "SPARK_YARN_QUEUE", - "--name" -> "SPARK_YARN_APP_NAME", - "--files" -> "SPARK_YARN_DIST_FILES", - "--archives" -> "SPARK_YARN_DIST_ARCHIVES") - .foreach { case (optName, optParam) => addArg(optName, optParam, argsArrayBuf) } + Map("SPARK_MASTER_MEMORY" -> "--driver-memory", + "SPARK_DRIVER_MEMORY" -> "--driver-memory", + "SPARK_WORKER_INSTANCES" -> "--num-executors", + "SPARK_WORKER_MEMORY" -> "--executor-memory", + "SPARK_WORKER_CORES" -> "--executor-cores", + "SPARK_EXECUTOR_INSTANCES" -> "--num-executors", + "SPARK_EXECUTOR_MEMORY" -> "--executor-memory", + "SPARK_EXECUTOR_CORES" -> "--executor-cores", + "SPARK_YARN_QUEUE" -> "--queue", + "SPARK_YARN_APP_NAME" -> "--name", + "SPARK_YARN_DIST_FILES" -> "--files", + "SPARK_YARN_DIST_ARCHIVES" -> "--archives") + .foreach { case (optParam, optName) => addArg(optName, optParam, argsArrayBuf) } logDebug("ClientArguments called with: " + argsArrayBuf) val args = new ClientArguments(argsArrayBuf.toArray, conf) @@ -81,7 +81,7 @@ private[spark] class YarnClientSchedulerBackend( def waitForApp() { - // TODO : need a better way to find out whether the workers are ready or not + // TODO : need a better way to find out whether the executors are ready or not // maybe by resource usage report? 
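YarnClientSchedulerBackend, updated above, now maps both the old SPARK_WORKER_* and the new SPARK_EXECUTOR_* environment variables onto the renamed command-line options, so existing deployments keep working. The mapping pattern on its own, with a local buffer in place of the real argsArrayBuf and only a subset of the variables:

    import scala.collection.mutable.ArrayBuffer

    val argsArrayBuf = new ArrayBuffer[String]()

    // Only environment variables that are actually set contribute an option;
    // old and new names feed the same flag, so either spelling works.
    def addArg(optName: String, envVar: String): Unit =
      sys.env.get(envVar).foreach { value => argsArrayBuf ++= Seq(optName, value) }

    Map("SPARK_WORKER_MEMORY"   -> "--executor-memory",
        "SPARK_EXECUTOR_MEMORY" -> "--executor-memory",
        "SPARK_WORKER_CORES"    -> "--executor-cores",
        "SPARK_EXECUTOR_CORES"  -> "--executor-cores")
      .foreach { case (envVar, optName) => addArg(optName, envVar) }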
while(true) { val report = client.getApplicationReport(appId) diff --git a/yarn/pom.xml b/yarn/pom.xml index e7eba36ba351b..35e31760c1f02 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.0.0-incubating-SNAPSHOT + 1.0.0-SNAPSHOT ../pom.xml @@ -52,14 +52,6 @@ hadoop-client ${yarn.version} - - org.apache.avro - avro - - - org.apache.avro - avro-ipc - org.scalatest scalatest_${scala.binary.version} @@ -78,6 +70,15 @@ alpha + + + + + org.apache.avro + avro + + @@ -133,7 +134,7 @@ true - + @@ -146,7 +147,7 @@ - + diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml index 7c312206d16d3..e7915d12aef63 100644 --- a/yarn/stable/pom.xml +++ b/yarn/stable/pom.xml @@ -20,10 +20,24 @@ org.apache.spark yarn-parent_2.10 - 1.0.0-incubating-SNAPSHOT + 1.0.0-SNAPSHOT ../pom.xml + + + + yarn-alpha + + + org.apache.avro + avro + + + + + org.apache.spark spark-yarn_2.10 jar diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 4b777d5fa7a28..30735cbfdf26e 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -27,7 +27,6 @@ import scala.collection.JavaConversions._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.net.NetUtils -import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.protocolrecords._ @@ -37,8 +36,10 @@ import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{ConverterUtils, Records} +import org.apache.hadoop.yarn.webapp.util.WebAppUtils; -import org.apache.spark.{SparkConf, SparkContext, Logging} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.Utils @@ -63,9 +64,14 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, private var isLastAMRetry: Boolean = true private var amClient: AMRMClient[ContainerRequest] = _ - // Default to numWorkers * 2, with minimum of 3 - private val maxNumWorkerFailures = sparkConf.getInt("spark.yarn.max.worker.failures", - math.max(args.numWorkers * 2, 3)) + // Default to numExecutors * 2, with minimum of 3 + private val maxNumExecutorFailures = sparkConf.getInt("spark.yarn.max.executor.failures", + sparkConf.getInt("spark.yarn.max.worker.failures", math.max(args.numExecutors * 2, 3))) + + private var registered = false + + private val sparkUser = Option(System.getenv("SPARK_USER")).getOrElse( + SparkContext.SPARK_UNKNOWN_USER) def run() { // Setup the directories so things go to YARN approved directories rather @@ -76,6 +82,9 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, // other spark processes running on the same box System.setProperty("spark.ui.port", "0") + // when running the AM, the Spark master is always "yarn-cluster" + System.setProperty("spark.master", "yarn-cluster") + // Use priority 30 as it's higher then HDFS. It's same priority as MapReduce is using. 
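
The new maxNumExecutorFailures default above layers the renamed configuration key over the deprecated one and a computed fallback. A small sketch of that lookup order (illustrative only, not part of the patch; it depends only on SparkConf):

import org.apache.spark.SparkConf

object MaxExecutorFailuresSketch {
  // Prefer the new key, fall back to the deprecated one, then to max(numExecutors * 2, 3).
  def maxNumExecutorFailures(sparkConf: SparkConf, numExecutors: Int): Int =
    sparkConf.getInt("spark.yarn.max.executor.failures",
      sparkConf.getInt("spark.yarn.max.worker.failures", math.max(numExecutors * 2, 3)))

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false)
    println(maxNumExecutorFailures(conf, numExecutors = 2))  // 4, the computed default
    conf.set("spark.yarn.max.worker.failures", "5")
    println(maxNumExecutorFailures(conf, numExecutors = 2))  // 5, deprecated key still honored
    conf.set("spark.yarn.max.executor.failures", "8")
    println(maxNumExecutorFailures(conf, numExecutors = 2))  // 8, new key wins
  }
}
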
ShutdownHookManager.get().addShutdownHook(new AppMasterShutdownHook(this), 30) @@ -85,12 +94,16 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, amClient.init(yarnConf) amClient.start() - // Workaround until hadoop moves to something which has - // https://issues.apache.org/jira/browse/HADOOP-8406 - fixed in (2.0.2-alpha but no 0.23 line) - // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf) + // setup AmIpFilter for the SparkUI - do this before we start the UI + addAmIpFilter() ApplicationMaster.register(this) + // Call this to force generation of secret so it gets populated into the + // hadoop UGI. This has to happen before the startUserClass which does a + // doAs in order for the credentials to be passed on to the executor containers. + val securityMgr = new SecurityManager(sparkConf) + // Start the user's JAR userThread = startUserClass() @@ -99,10 +112,15 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, waitForSparkContextInitialized() // Do this after Spark master is up and SparkContext is created so that we can register UI Url. - val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster() + synchronized { + if (!isFinished) { + registerApplicationMaster() + registered = true + } + } // Allocate all containers - allocateWorkers() + allocateExecutors() // Wait for the user class to Finish userThread.join() @@ -110,6 +128,19 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, System.exit(0) } + // add the yarn amIpFilter that Yarn requires for properly securing the UI + private def addAmIpFilter() { + val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" + System.setProperty("spark.ui.filters", amFilter) + val proxy = WebAppUtils.getProxyHostAndPort(conf) + val parts : Array[String] = proxy.split(":") + val uriBase = "http://" + proxy + + System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) + + val params = "PROXY_HOST=" + parts(0) + "," + "PROXY_URI_BASE=" + uriBase + System.setProperty("spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.params", params) + } + /** Get the Yarn approved local directories. */ private def getLocalDirs(): String = { // Hadoop 0.23 and 2.x have different Environment variable names for the @@ -145,7 +176,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, false /* initialize */ , Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]]) val t = new Thread { - override def run() { + override def run(): Unit = SparkHadoopUtil.get.runAsUser(sparkUser) { () => var successed = false try { // Copy @@ -171,7 +202,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, t } - // This need to happen before allocateWorkers() + // This need to happen before allocateExecutors() private def waitForSparkContextInitialized() { logInfo("Waiting for Spark context initialization") try { @@ -180,7 +211,8 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, var numTries = 0 val waitTime = 10000L val maxNumTries = sparkConf.getInt("spark.yarn.applicationMaster.waitTries", 10) - while (ApplicationMaster.sparkContextRef.get() == null && numTries < maxNumTries) { + while (ApplicationMaster.sparkContextRef.get() == null && numTries < maxNumTries + && !isFinished) { logInfo("Waiting for Spark context initialization ... 
" + numTries) numTries = numTries + 1 ApplicationMaster.sparkContextRef.wait(waitTime) @@ -215,18 +247,18 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, } } - private def allocateWorkers() { + private def allocateExecutors() { try { - logInfo("Allocating " + args.numWorkers + " workers.") + logInfo("Allocating " + args.numExecutors + " executors.") // Wait until all containers have finished // TODO: This is a bit ugly. Can we make it nicer? // TODO: Handle container failure - yarnAllocator.addResourceRequests(args.numWorkers) + yarnAllocator.addResourceRequests(args.numExecutors) // Exits the loop if the user thread exits. - while (yarnAllocator.getNumWorkersRunning < args.numWorkers && userThread.isAlive) { - if (yarnAllocator.getNumWorkersFailed >= maxNumWorkerFailures) { + while (yarnAllocator.getNumExecutorsRunning < args.numExecutors && userThread.isAlive) { + if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { finishApplicationMaster(FinalApplicationStatus.FAILED, - "max number of worker failures reached") + "max number of executor failures reached") } yarnAllocator.allocateResources() ApplicationMaster.incrementAllocatorLoop(1) @@ -237,7 +269,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, // so that the loop in ApplicationMaster#sparkContextInitialized() breaks. ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT) } - logInfo("All workers have launched.") + logInfo("All executors have launched.") // Launch a progress reporter thread, else the app will get killed after expiration // (def: 10mins) timeout. @@ -249,7 +281,6 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, val schedulerInterval = sparkConf.getLong("spark.yarn.scheduler.heartbeat.interval-ms", 5000) - // must be <= timeoutInterval / 2. val interval = math.min(timeoutInterval / 2, schedulerInterval) @@ -263,16 +294,16 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, val t = new Thread { override def run() { while (userThread.isAlive) { - if (yarnAllocator.getNumWorkersFailed >= maxNumWorkerFailures) { + if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { finishApplicationMaster(FinalApplicationStatus.FAILED, - "max number of worker failures reached") + "max number of executor failures reached") } - val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning - + val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning - yarnAllocator.getNumPendingAllocate - if (missingWorkerCount > 0) { + if (missingExecutorCount > 0) { logInfo("Allocating %d containers to make up for (potentially) lost containers". - format(missingWorkerCount)) - yarnAllocator.addResourceRequests(missingWorkerCount) + format(missingExecutorCount)) + yarnAllocator.addResourceRequests(missingExecutorCount) } sendProgress() Thread.sleep(sleepTime) @@ -313,11 +344,13 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, return } isFinished = true - } - logInfo("finishApplicationMaster with " + status) - // Set tracking URL to empty since we don't have a history server. - amClient.unregisterApplicationMaster(status, "" /* appMessage */ , "" /* appTrackingUrl */) + logInfo("finishApplicationMaster with " + status) + if (registered) { + // Set tracking URL to empty since we don't have a history server. 
+ amClient.unregisterApplicationMaster(status, "" /* appMessage */ , "" /* appTrackingUrl */) + } + } } /** diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala similarity index 89% rename from yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala rename to yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index 40600f38e5e73..b697f103914fd 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -28,14 +28,14 @@ import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import akka.actor._ import akka.remote._ import akka.actor.Terminated -import org.apache.spark.{SparkConf, SparkContext, Logging} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.util.{Utils, AkkaUtils} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.scheduler.SplitInfo import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest -class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, sparkConf: SparkConf) +class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sparkConf: SparkConf) extends Logging { def this(args: ApplicationMasterArguments, sparkConf: SparkConf) = @@ -52,8 +52,9 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar private var amClient: AMRMClient[ContainerRequest] = _ + val securityManager = new SecurityManager(sparkConf) val actorSystem: ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, - conf = sparkConf)._1 + conf = sparkConf, securityManager = securityManager)._1 var actor: ActorRef = _ // This actor just working as a monitor to watch on Driver Actor. @@ -92,7 +93,7 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar waitForSparkMaster() // Allocate all containers - allocateWorkers() + allocateExecutors() // Launch a progress reporter thread, else app will get killed after expiration (def: 10mins) timeout // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse. @@ -105,6 +106,7 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar val interval = math.min(timeoutInterval / 2, schedulerInterval) reporterThread = launchReporterThread(interval) + // Wait for the reporter thread to Finish. reporterThread.join() @@ -173,7 +175,7 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar } - private def allocateWorkers() { + private def allocateExecutors() { // Fixme: should get preferredNodeLocationData from SparkContext, just fake a empty one for now. val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = @@ -187,18 +189,18 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar preferredNodeLocationData, sparkConf) - logInfo("Allocating " + args.numWorkers + " workers.") + logInfo("Allocating " + args.numExecutors + " executors.") // Wait until all containers have finished // TODO: This is a bit ugly. Can we make it nicer? 
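
Both ApplicationMaster and ExecutorLauncher drive allocation with the same loop shape, which the hunks above and below rename from workers to executors: each pass re-requests whatever is missing and keeps calling into the allocator so the ResourceManager sees regular progress. A self-contained sketch of one pass (illustrative only; the Allocator trait stands in for YarnAllocationHandler):

object AllocationLoopSketch {
  trait Allocator {
    def getNumExecutorsRunning: Int
    def getNumExecutorsFailed: Int
    def getNumPendingAllocate: Int
    def addResourceRequests(count: Int): Unit
    def allocateResources(): Unit
  }

  // Returns false once the executor-failure threshold has been crossed.
  def heartbeat(allocator: Allocator, numExecutors: Int, maxNumExecutorFailures: Int): Boolean = {
    if (allocator.getNumExecutorsFailed >= maxNumExecutorFailures) {
      // The real AM calls finishApplicationMaster(FinalApplicationStatus.FAILED, ...) here.
      return false
    }
    val missingExecutorCount =
      numExecutors - allocator.getNumExecutorsRunning - allocator.getNumPendingAllocate
    if (missingExecutorCount > 0) {
      allocator.addResourceRequests(missingExecutorCount)
    }
    // Calling into the allocator also serves as the periodic progress signal to the RM.
    allocator.allocateResources()
    true
  }
}
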
// TODO: Handle container failure - yarnAllocator.addResourceRequests(args.numWorkers) - while ((yarnAllocator.getNumWorkersRunning < args.numWorkers) && (!driverClosed)) { + yarnAllocator.addResourceRequests(args.numExecutors) + while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed)) { yarnAllocator.allocateResources() Thread.sleep(100) } - logInfo("All workers have launched.") + logInfo("All executors have launched.") } @@ -209,12 +211,12 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar val t = new Thread { override def run() { while (!driverClosed) { - val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning - + val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning - yarnAllocator.getNumPendingAllocate - if (missingWorkerCount > 0) { + if (missingExecutorCount > 0) { logInfo("Allocating %d containers to make up for (potentially) lost containers". - format(missingWorkerCount)) - yarnAllocator.addResourceRequests(missingWorkerCount) + format(missingExecutorCount)) + yarnAllocator.addResourceRequests(missingExecutorCount) } sendProgress() Thread.sleep(sleepTime) @@ -242,9 +244,9 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar } -object WorkerLauncher { +object ExecutorLauncher { def main(argStrings: Array[String]) { val args = new ApplicationMasterArguments(argStrings) - new WorkerLauncher(args).run() + new ExecutorLauncher(args).run() } } diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala similarity index 90% rename from yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala rename to yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index ab4a79be70485..53c403f7d0913 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -38,16 +38,16 @@ import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records} import org.apache.spark.{SparkConf, Logging} -class WorkerRunnable( +class ExecutorRunnable( container: Container, conf: Configuration, spConf: SparkConf, masterAddress: String, slaveId: String, hostname: String, - workerMemory: Int, - workerCores: Int) - extends Runnable with WorkerRunnableUtil with Logging { + executorMemory: Int, + executorCores: Int) + extends Runnable with ExecutorRunnableUtil with Logging { var rpc: YarnRPC = YarnRPC.create(conf) var nmClient: NMClient = _ @@ -55,7 +55,7 @@ class WorkerRunnable( val yarnConf: YarnConfiguration = new YarnConfiguration(conf) def run = { - logInfo("Starting Worker Container") + logInfo("Starting Executor Container") nmClient = NMClient.createNMClient() nmClient.init(yarnConf) nmClient.start() @@ -78,9 +78,9 @@ class WorkerRunnable( credentials.writeTokenStorageToStream(dob) ctx.setTokens(ByteBuffer.wrap(dob.getData())) - val commands = prepareCommand(masterAddress, slaveId, hostname, workerMemory, workerCores) + val commands = prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores) - logInfo("Setting up worker with commands: " + commands) + logInfo("Setting up executor with commands: " + commands) ctx.setCommands(commands) // Send the start request to the ContainerManager diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala 
b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index 1ac61124cb028..e31c4060e8452 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -60,9 +60,9 @@ private[yarn] class YarnAllocationHandler( val conf: Configuration, val amClient: AMRMClient[ContainerRequest], val appAttemptId: ApplicationAttemptId, - val maxWorkers: Int, - val workerMemory: Int, - val workerCores: Int, + val maxExecutors: Int, + val executorMemory: Int, + val executorCores: Int, val preferredHostToCount: Map[String, Int], val preferredRackToCount: Map[String, Int], val sparkConf: SparkConf) @@ -89,20 +89,20 @@ private[yarn] class YarnAllocationHandler( // Number of container requests that have been sent to, but not yet allocated by the // ApplicationMaster. private val numPendingAllocate = new AtomicInteger() - private val numWorkersRunning = new AtomicInteger() - // Used to generate a unique id per worker - private val workerIdCounter = new AtomicInteger() + private val numExecutorsRunning = new AtomicInteger() + // Used to generate a unique id per executor + private val executorIdCounter = new AtomicInteger() private val lastResponseId = new AtomicInteger() - private val numWorkersFailed = new AtomicInteger() + private val numExecutorsFailed = new AtomicInteger() def getNumPendingAllocate: Int = numPendingAllocate.intValue - def getNumWorkersRunning: Int = numWorkersRunning.intValue + def getNumExecutorsRunning: Int = numExecutorsRunning.intValue - def getNumWorkersFailed: Int = numWorkersFailed.intValue + def getNumExecutorsFailed: Int = numExecutorsFailed.intValue def isResourceConstraintSatisfied(container: Container): Boolean = { - container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + container.getResource.getMemory >= (executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD) } def releaseContainer(container: Container) { @@ -127,13 +127,13 @@ private[yarn] class YarnAllocationHandler( logDebug(""" Allocated containers: %d - Current worker count: %d + Current executor count: %d Containers released: %s Containers to-be-released: %s Cluster resources: %s """.format( allocatedContainers.size, - numWorkersRunning.get(), + numExecutorsRunning.get(), releasedContainerList, pendingReleaseContainers, allocateResponse.getAvailableResources)) @@ -240,64 +240,64 @@ private[yarn] class YarnAllocationHandler( // Run each of the allocated containers. 
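
The isResourceConstraintSatisfied check above, and the request-building code further down, size every executor container as executorMemory plus a fixed overhead. A small sketch of that arithmetic (illustrative only; the 384 MB figure mirrors YarnAllocationHandler.MEMORY_OVERHEAD in this version of the code, a constant defined outside these hunks):

object ContainerSizingSketch {
  val MemoryOverhead = 384  // MB; assumed to match YarnAllocationHandler.MEMORY_OVERHEAD

  def requestedMemory(executorMemory: Int): Int = executorMemory + MemoryOverhead

  def isResourceConstraintSatisfied(containerMemory: Int, executorMemory: Int): Boolean =
    containerMemory >= requestedMemory(executorMemory)

  def main(args: Array[String]): Unit = {
    // A 2048 MB executor needs a container of at least 2048 + 384 = 2432 MB.
    println(requestedMemory(2048))                      // 2432
    println(isResourceConstraintSatisfied(2048, 2048))  // false
    println(isResourceConstraintSatisfied(2432, 2048))  // true
  }
}
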
for (container <- allocatedContainersToProcess) { - val numWorkersRunningNow = numWorkersRunning.incrementAndGet() - val workerHostname = container.getNodeId.getHost + val numExecutorsRunningNow = numExecutorsRunning.incrementAndGet() + val executorHostname = container.getNodeId.getHost val containerId = container.getId - val workerMemoryOverhead = (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) - assert(container.getResource.getMemory >= workerMemoryOverhead) + val executorMemoryOverhead = (executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + assert(container.getResource.getMemory >= executorMemoryOverhead) - if (numWorkersRunningNow > maxWorkers) { + if (numExecutorsRunningNow > maxExecutors) { logInfo("""Ignoring container %s at host %s, since we already have the required number of - containers for it.""".format(containerId, workerHostname)) + containers for it.""".format(containerId, executorHostname)) releaseContainer(container) - numWorkersRunning.decrementAndGet() + numExecutorsRunning.decrementAndGet() } else { - val workerId = workerIdCounter.incrementAndGet().toString + val executorId = executorIdCounter.incrementAndGet().toString val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( sparkConf.get("spark.driver.host"), sparkConf.get("spark.driver.port"), CoarseGrainedSchedulerBackend.ACTOR_NAME) - logInfo("Launching container %s for on host %s".format(containerId, workerHostname)) + logInfo("Launching container %s on host %s".format(containerId, executorHostname)) // To be safe, remove the container from `pendingReleaseContainers`. pendingReleaseContainers.remove(containerId) - val rack = YarnAllocationHandler.lookupRack(conf, workerHostname) + val rack = YarnAllocationHandler.lookupRack(conf, executorHostname) allocatedHostToContainersMap.synchronized { - val containerSet = allocatedHostToContainersMap.getOrElseUpdate(workerHostname, + val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname, new HashSet[ContainerId]()) containerSet += containerId - allocatedContainerToHostMap.put(containerId, workerHostname) + allocatedContainerToHostMap.put(containerId, executorHostname) if (rack != null) { allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1) } } - logInfo("Launching WorkerRunnable. driverUrl: %s, workerHostname: %s".format(driverUrl, workerHostname)) - val workerRunnable = new WorkerRunnable( + logInfo("Launching ExecutorRunnable. driverUrl: %s, executorHostname: %s".format(driverUrl, executorHostname)) + val executorRunnable = new ExecutorRunnable( container, conf, sparkConf, driverUrl, - workerId, - workerHostname, - workerMemory, - workerCores) - new Thread(workerRunnable).start() + executorId, + executorHostname, + executorMemory, + executorCores) + new Thread(executorRunnable).start() } } logDebug(""" Finished allocating %s containers (from %s originally). - Current number of workers running: %d, + Current number of executors running: %d, releasedContainerList: %s, pendingReleaseContainers: %s """.format( allocatedContainersToProcess, allocatedContainers, - numWorkersRunning.get(), + numExecutorsRunning.get(), releasedContainerList, pendingReleaseContainers)) } @@ -314,9 +314,9 @@ private[yarn] class YarnAllocationHandler( // `pendingReleaseContainers`. pendingReleaseContainers.remove(containerId) } else { - // Decrement the number of workers running. The next iteration of the ApplicationMaster's + // Decrement the number of executors running.
The next iteration of the ApplicationMaster's // reporting thread will take care of allocating. - numWorkersRunning.decrementAndGet() + numExecutorsRunning.decrementAndGet() logInfo("Completed container %s (state: %s, exit status: %s)".format( containerId, completedContainer.getState, @@ -326,7 +326,7 @@ private[yarn] class YarnAllocationHandler( // now I think its ok as none of the containers are expected to exit if (completedContainer.getExitStatus() != 0) { logInfo("Container marked as failed: " + containerId) - numWorkersFailed.incrementAndGet() + numExecutorsFailed.incrementAndGet() } } @@ -364,12 +364,12 @@ private[yarn] class YarnAllocationHandler( } logDebug(""" Finished processing %d completed containers. - Current number of workers running: %d, + Current number of executors running: %d, releasedContainerList: %s, pendingReleaseContainers: %s """.format( completedContainers.size, - numWorkersRunning.get(), + numExecutorsRunning.get(), releasedContainerList, pendingReleaseContainers)) } @@ -421,18 +421,18 @@ private[yarn] class YarnAllocationHandler( retval } - def addResourceRequests(numWorkers: Int) { + def addResourceRequests(numExecutors: Int) { val containerRequests: List[ContainerRequest] = - if (numWorkers <= 0 || preferredHostToCount.isEmpty) { - logDebug("numWorkers: " + numWorkers + ", host preferences: " + + if (numExecutors <= 0 || preferredHostToCount.isEmpty) { + logDebug("numExecutors: " + numExecutors + ", host preferences: " + preferredHostToCount.isEmpty) createResourceRequests( AllocationType.ANY, resource = null, - numWorkers, + numExecutors, YarnAllocationHandler.PRIORITY).toList } else { - // Request for all hosts in preferred nodes and for numWorkers - + // Request for all hosts in preferred nodes and for numExecutors - // candidates.size, request by default allocation policy. val hostContainerRequests = new ArrayBuffer[ContainerRequest](preferredHostToCount.size) for ((candidateHost, candidateCount) <- preferredHostToCount) { @@ -452,7 +452,7 @@ private[yarn] class YarnAllocationHandler( val anyContainerRequests = createResourceRequests( AllocationType.ANY, resource = null, - numWorkers, + numExecutors, YarnAllocationHandler.PRIORITY) val containerRequestBuffer = new ArrayBuffer[ContainerRequest]( @@ -468,11 +468,11 @@ private[yarn] class YarnAllocationHandler( amClient.addContainerRequest(request) } - if (numWorkers > 0) { - numPendingAllocate.addAndGet(numWorkers) - logInfo("Will Allocate %d worker containers, each with %d memory".format( - numWorkers, - (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD))) + if (numExecutors > 0) { + numPendingAllocate.addAndGet(numExecutors) + logInfo("Will Allocate %d executor containers, each with %d memory".format( + numExecutors, + (executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD))) } else { logDebug("Empty allocation request ...") } @@ -494,7 +494,7 @@ private[yarn] class YarnAllocationHandler( private def createResourceRequests( requestType: AllocationType.AllocationType, resource: String, - numWorkers: Int, + numExecutors: Int, priority: Int ): ArrayBuffer[ContainerRequest] = { @@ -507,7 +507,7 @@ private[yarn] class YarnAllocationHandler( val nodeLocal = constructContainerRequests( Array(hostname), racks = null, - numWorkers, + numExecutors, priority) // Add `hostname` to the global (singleton) host->rack mapping in YarnAllocationHandler. 
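
A condensed sketch (not part of the patch) of the request fan-out in addResourceRequests above: with no host preferences everything becomes an ANY request; otherwise one batch is issued per preferred host, plus an ANY batch for numExecutors so the default allocation policy can fill whatever the preferred hosts cannot supply. The real code also derives rack-level requests from the preferred hosts; the case classes below are stand-ins for the ContainerRequest batches built by createResourceRequests:

object RequestFanOutSketch {
  sealed trait RequestBatch
  case class HostRequests(host: String, count: Int) extends RequestBatch
  case class AnyRequests(count: Int) extends RequestBatch

  def fanOut(numExecutors: Int, preferredHostToCount: Map[String, Int]): Seq[RequestBatch] = {
    if (numExecutors <= 0 || preferredHostToCount.isEmpty) {
      // No locality preferences: everything is an ANY request.
      Seq(AnyRequests(math.max(numExecutors, 0)))
    } else {
      val hostBatches = preferredHostToCount.toSeq.map { case (host, count) =>
        HostRequests(host, count)
      }
      // Top up with ANY requests so allocation can proceed off the preferred hosts.
      hostBatches :+ AnyRequests(numExecutors)
    }
  }

  def main(args: Array[String]): Unit = {
    println(fanOut(4, Map.empty))                        // List(AnyRequests(4))
    println(fanOut(4, Map("host1" -> 2, "host2" -> 1)))  // per-host batches plus AnyRequests(4)
  }
}
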
@@ -516,10 +516,10 @@ private[yarn] class YarnAllocationHandler( } case AllocationType.RACK => { val rack = resource - constructContainerRequests(hosts = null, Array(rack), numWorkers, priority) + constructContainerRequests(hosts = null, Array(rack), numExecutors, priority) } case AllocationType.ANY => constructContainerRequests( - hosts = null, racks = null, numWorkers, priority) + hosts = null, racks = null, numExecutors, priority) case _ => throw new IllegalArgumentException( "Unexpected/unsupported request type: " + requestType) } @@ -528,18 +528,18 @@ private[yarn] class YarnAllocationHandler( private def constructContainerRequests( hosts: Array[String], racks: Array[String], - numWorkers: Int, + numExecutors: Int, priority: Int ): ArrayBuffer[ContainerRequest] = { - val memoryRequest = workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD - val resource = Resource.newInstance(memoryRequest, workerCores) + val memoryRequest = executorMemory + YarnAllocationHandler.MEMORY_OVERHEAD + val resource = Resource.newInstance(memoryRequest, executorCores) val prioritySetting = Records.newRecord(classOf[Priority]) prioritySetting.setPriority(priority) val requests = new ArrayBuffer[ContainerRequest]() - for (i <- 0 until numWorkers) { + for (i <- 0 until numExecutors) { requests += new ContainerRequest(resource, hosts, racks, prioritySetting) } requests @@ -574,9 +574,9 @@ object YarnAllocationHandler { conf, amClient, appAttemptId, - args.numWorkers, - args.workerMemory, - args.workerCores, + args.numExecutors, + args.executorMemory, + args.executorCores, Map[String, Int](), Map[String, Int](), sparkConf) @@ -596,9 +596,9 @@ object YarnAllocationHandler { conf, amClient, appAttemptId, - args.numWorkers, - args.workerMemory, - args.workerCores, + args.numExecutors, + args.executorMemory, + args.executorCores, hostToSplitCount, rackToSplitCount, sparkConf) @@ -608,9 +608,9 @@ object YarnAllocationHandler { conf: Configuration, amClient: AMRMClient[ContainerRequest], appAttemptId: ApplicationAttemptId, - maxWorkers: Int, - workerMemory: Int, - workerCores: Int, + maxExecutors: Int, + executorMemory: Int, + executorCores: Int, map: collection.Map[String, collection.Set[SplitInfo]], sparkConf: SparkConf ): YarnAllocationHandler = { @@ -619,9 +619,9 @@ object YarnAllocationHandler { conf, amClient, appAttemptId, - maxWorkers, - workerMemory, - workerCores, + maxExecutors, + executorMemory, + executorCores, hostToCount, rackToCount, sparkConf)