diff --git a/LICENSE b/LICENSE
index 1c166d1333614..1c1c2c0255fa9 100644
--- a/LICENSE
+++ b/LICENSE
@@ -396,3 +396,35 @@ INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
+
+
+========================================================================
+For sbt and sbt-launch-lib.bash in sbt/:
+========================================================================
+
+// Generated from http://www.opensource.org/licenses/bsd-license.php
+Copyright (c) 2011, Paul Phillips.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice,
+ this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+ * Neither the name of the author nor the names of its contributors may be
+ used to endorse or promote products derived from this software without
+ specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
+EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/NOTICE b/NOTICE
index 7cbb114b2ae2d..dce0c4eaf31ed 100644
--- a/NOTICE
+++ b/NOTICE
@@ -1,5 +1,5 @@
Apache Spark
-Copyright 2013 The Apache Software Foundation.
+Copyright 2014 The Apache Software Foundation.
This product includes software developed at
The Apache Software Foundation (http://www.apache.org/).
diff --git a/README.md b/README.md
index c840a68f76b17..dc8135b9b8b51 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,12 @@
# Apache Spark
-Lightning-Fast Cluster Computing - <http://spark.incubator.apache.org/>
+Lightning-Fast Cluster Computing - <http://spark.apache.org/>
## Online Documentation
You can find the latest Spark documentation, including a programming
-guide, on the project webpage at <http://spark.incubator.apache.org/documentation.html>.
+guide, on the project webpage at <http://spark.apache.org/documentation.html>.
This README file only contains basic setup instructions.
@@ -92,21 +92,10 @@ If your project is built with Maven, add this to your POM file's `<dependencies>` section:
## Configuration
-Please refer to the [Configuration guide](http://spark.incubator.apache.org/docs/latest/configuration.html)
+Please refer to the [Configuration guide](http://spark.apache.org/docs/latest/configuration.html)
in the online documentation for an overview on how to configure Spark.
-## Apache Incubator Notice
-
-Apache Spark is an effort undergoing incubation at The Apache Software
-Foundation (ASF), sponsored by the Apache Incubator. Incubation is required of
-all newly accepted projects until a further review indicates that the
-infrastructure, communications, and decision making process have stabilized in
-a manner consistent with other successful ASF projects. While incubation status
-is not necessarily a reflection of the completeness or stability of the code,
-it does indicate that the project has yet to be fully endorsed by the ASF.
-
-
## Contributing to Spark
Contributions via GitHub pull requests are gladly accepted from their original
diff --git a/assembly/pom.xml b/assembly/pom.xml
index 82396040251d3..22bbbc57d81d4 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -21,17 +21,20 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent</artifactId>
-    <version>1.0.0-incubating-SNAPSHOT</version>
+    <version>1.0.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
   </parent>
 
   <groupId>org.apache.spark</groupId>
   <artifactId>spark-assembly_2.10</artifactId>
   <name>Spark Project Assembly</name>
-  <url>http://spark.incubator.apache.org/</url>
+  <url>http://spark.apache.org/</url>
+  <packaging>pom</packaging>
 
   <properties>
-    <spark.jar>${project.build.directory}/scala-${scala.binary.version}/${project.artifactId}-${project.version}-hadoop${hadoop.version}.jar</spark.jar>
+    <spark.jar.dir>scala-${scala.binary.version}</spark.jar.dir>
+    <spark.jar.basename>${project.artifactId}-${project.version}-hadoop${hadoop.version}.jar</spark.jar.basename>
+    <spark.jar>${project.build.directory}/${spark.jar.dir}/${spark.jar.basename}</spark.jar>
     <deb.pkg.name>spark</deb.pkg.name>
     <deb.install.path>/usr/share/spark</deb.install.path>
     <deb.user>root</deb.user>
@@ -155,6 +158,16 @@
         </dependency>
       </dependencies>
     </profile>
+    <profile>
+      <id>spark-ganglia-lgpl</id>
+      <dependencies>
+        <dependency>
+          <groupId>org.apache.spark</groupId>
+          <artifactId>spark-ganglia-lgpl_${scala.binary.version}</artifactId>
+          <version>${project.version}</version>
+        </dependency>
+      </dependencies>
+    </profile>
     <profile>
       <id>bigtop-dist</id>
       <!--
+    <profile>
+      <id>yarn-alpha</id>
+      <dependencies>
+        <dependency>
+          <groupId>org.apache.avro</groupId>
+          <artifactId>avro</artifactId>
+        </dependency>
+      </dependencies>
+    </profile>
diff --git a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala
index dd3eed8affe39..70c7474a936dc 100644
--- a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala
+++ b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala
@@ -27,7 +27,7 @@ object Bagel extends Logging {
/**
* Runs a Bagel program.
- * @param sc [[org.apache.spark.SparkContext]] to use for the program.
+ * @param sc org.apache.spark.SparkContext to use for the program.
* @param vertices vertices of the graph represented as an RDD of (Key, Vertex) pairs. Often the
* Key will be the vertex id.
* @param messages initial set of messages represented as an RDD of (Key, Message) pairs. Often
@@ -38,10 +38,10 @@ object Bagel extends Logging {
* @param aggregator [[org.apache.spark.bagel.Aggregator]] performs a reduce across all vertices
* after each superstep and provides the result to each vertex in the next
* superstep.
- * @param partitioner [[org.apache.spark.Partitioner]] partitions values by key
+ * @param partitioner org.apache.spark.Partitioner partitions values by key
* @param numPartitions number of partitions across which to split the graph.
* Default is the default parallelism of the SparkContext
- * @param storageLevel [[org.apache.spark.storage.StorageLevel]] to use for caching of
+ * @param storageLevel org.apache.spark.storage.StorageLevel to use for caching of
* intermediate RDDs in each superstep. Defaults to caching in memory.
* @param compute function that takes a Vertex, optional set of (possibly combined) messages to
* the Vertex, optional Aggregator and the current superstep,
@@ -131,7 +131,7 @@ object Bagel extends Logging {
/**
* Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]], default
- * [[org.apache.spark.HashPartitioner]] and default storage level
+ * org.apache.spark.HashPartitioner and default storage level
*/
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest](
sc: SparkContext,
@@ -146,7 +146,7 @@ object Bagel extends Logging {
/**
* Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]] and the
- * default [[org.apache.spark.HashPartitioner]]
+ * default org.apache.spark.HashPartitioner
*/
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest](
sc: SparkContext,
@@ -166,7 +166,7 @@ object Bagel extends Logging {
/**
* Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]],
- * default [[org.apache.spark.HashPartitioner]],
+ * default org.apache.spark.HashPartitioner,
* [[org.apache.spark.bagel.DefaultCombiner]] and the default storage level
*/
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest](
@@ -180,7 +180,7 @@ object Bagel extends Logging {
/**
* Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]],
- * the default [[org.apache.spark.HashPartitioner]]
+ * the default org.apache.spark.HashPartitioner
* and [[org.apache.spark.bagel.DefaultCombiner]]
*/
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest](
diff --git a/bin/spark-class b/bin/spark-class
index c4225a392d6da..229ae2cebbab3 100755
--- a/bin/spark-class
+++ b/bin/spark-class
@@ -40,34 +40,46 @@ if [ -z "$1" ]; then
exit 1
fi
-# If this is a standalone cluster daemon, reset SPARK_JAVA_OPTS and SPARK_MEM to reasonable
-# values for that; it doesn't need a lot
-if [ "$1" = "org.apache.spark.deploy.master.Master" -o "$1" = "org.apache.spark.deploy.worker.Worker" ]; then
- SPARK_MEM=${SPARK_DAEMON_MEMORY:-512m}
- SPARK_DAEMON_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS -Dspark.akka.logLifecycleEvents=true"
- # Do not overwrite SPARK_JAVA_OPTS environment variable in this script
- OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS" # Empty by default
-else
- OUR_JAVA_OPTS="$SPARK_JAVA_OPTS"
+if [ -n "$SPARK_MEM" ]; then
+ echo "Warning: SPARK_MEM is deprecated, please use a more specific config option"
+ echo "(e.g., spark.executor.memory or SPARK_DRIVER_MEMORY)."
fi
+# Use SPARK_MEM or 512m as the default memory, to be overridden by specific options
+DEFAULT_MEM=${SPARK_MEM:-512m}
+
+SPARK_DAEMON_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS -Dspark.akka.logLifecycleEvents=true"
-# Add java opts for master, worker, executor. The opts maybe null
+# Add java opts and memory settings for master, worker, executors, and repl.
case "$1" in
+ # Master and Worker use SPARK_DAEMON_JAVA_OPTS (and specific opts) + SPARK_DAEMON_MEMORY.
'org.apache.spark.deploy.master.Master')
- OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_MASTER_OPTS"
+ OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS $SPARK_MASTER_OPTS"
+ OUR_JAVA_MEM=${SPARK_DAEMON_MEMORY:-$DEFAULT_MEM}
;;
'org.apache.spark.deploy.worker.Worker')
- OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_WORKER_OPTS"
+ OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS $SPARK_WORKER_OPTS"
+ OUR_JAVA_MEM=${SPARK_DAEMON_MEMORY:-$DEFAULT_MEM}
;;
+
+ # Executors use SPARK_JAVA_OPTS + SPARK_EXECUTOR_MEMORY.
'org.apache.spark.executor.CoarseGrainedExecutorBackend')
- OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
+ OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
+ OUR_JAVA_MEM=${SPARK_EXECUTOR_MEMORY:-$DEFAULT_MEM}
;;
'org.apache.spark.executor.MesosExecutorBackend')
- OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
+ OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
+ OUR_JAVA_MEM=${SPARK_EXECUTOR_MEMORY:-$DEFAULT_MEM}
;;
+
+ # All drivers use SPARK_JAVA_OPTS + SPARK_DRIVER_MEMORY. The repl also uses SPARK_REPL_OPTS.
'org.apache.spark.repl.Main')
- OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_REPL_OPTS"
+ OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_REPL_OPTS"
+ OUR_JAVA_MEM=${SPARK_DRIVER_MEMORY:-$DEFAULT_MEM}
+ ;;
+ *)
+ OUR_JAVA_OPTS="$SPARK_JAVA_OPTS"
+ OUR_JAVA_MEM=${SPARK_DRIVER_MEMORY:-$DEFAULT_MEM}
;;
esac
@@ -83,14 +95,10 @@ else
fi
fi
-# Set SPARK_MEM if it isn't already set since we also use it for this process
-SPARK_MEM=${SPARK_MEM:-512m}
-export SPARK_MEM
-
# Set JAVA_OPTS to be able to load native libraries and to set heap size
JAVA_OPTS="$OUR_JAVA_OPTS"
JAVA_OPTS="$JAVA_OPTS -Djava.library.path=$SPARK_LIBRARY_PATH"
-JAVA_OPTS="$JAVA_OPTS -Xms$SPARK_MEM -Xmx$SPARK_MEM"
+JAVA_OPTS="$JAVA_OPTS -Xms$OUR_JAVA_MEM -Xmx$OUR_JAVA_MEM"
# Load extra JAVA_OPTS from conf/java-opts, if it exists
if [ -e "$FWDIR/conf/java-opts" ] ; then
JAVA_OPTS="$JAVA_OPTS `cat $FWDIR/conf/java-opts`"
diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd
index 80818c78ec24b..f488cfdbeceb6 100755
--- a/bin/spark-class2.cmd
+++ b/bin/spark-class2.cmd
@@ -34,22 +34,45 @@ if not "x%1"=="x" goto arg_given
goto exit
:arg_given
-set RUNNING_DAEMON=0
-if "%1"=="spark.deploy.master.Master" set RUNNING_DAEMON=1
-if "%1"=="spark.deploy.worker.Worker" set RUNNING_DAEMON=1
-if "x%SPARK_DAEMON_MEMORY%" == "x" set SPARK_DAEMON_MEMORY=512m
+if not "x%SPARK_MEM%"=="x" (
+ echo Warning: SPARK_MEM is deprecated, please use a more specific config option
+ echo e.g., spark.executor.memory or SPARK_DRIVER_MEMORY.
+)
+
+rem Use SPARK_MEM or 512m as the default memory, to be overridden by specific options
+set OUR_JAVA_MEM=%SPARK_MEM%
+if "x%OUR_JAVA_MEM%"=="x" set OUR_JAVA_MEM=512m
+
set SPARK_DAEMON_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% -Dspark.akka.logLifecycleEvents=true
-if "%RUNNING_DAEMON%"=="1" set SPARK_MEM=%SPARK_DAEMON_MEMORY%
-rem Do not overwrite SPARK_JAVA_OPTS environment variable in this script
-if "%RUNNING_DAEMON%"=="0" set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS%
-if "%RUNNING_DAEMON%"=="1" set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS%
-rem Figure out how much memory to use per executor and set it as an environment
-rem variable so that our process sees it and can report it to Mesos
-if "x%SPARK_MEM%"=="x" set SPARK_MEM=512m
+rem Add java opts and memory settings for master, worker, executors, and repl.
+rem Master and Worker use SPARK_DAEMON_JAVA_OPTS (and specific opts) + SPARK_DAEMON_MEMORY.
+if "%1"=="org.apache.spark.deploy.master.Master" (
+ set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% %SPARK_MASTER_OPTS%
+ if not "x%SPARK_DAEMON_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DAEMON_MEMORY%
+) else if "%1"=="org.apache.spark.deploy.worker.Worker" (
+ set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% %SPARK_WORKER_OPTS%
+ if not "x%SPARK_DAEMON_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DAEMON_MEMORY%
+
+rem Executors use SPARK_JAVA_OPTS + SPARK_EXECUTOR_MEMORY.
+) else if "%1"=="org.apache.spark.executor.CoarseGrainedExecutorBackend" (
+ set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% %SPARK_EXECUTOR_OPTS%
+ if not "x%SPARK_EXECUTOR_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_EXECUTOR_MEMORY%
+) else if "%1"=="org.apache.spark.executor.MesosExecutorBackend" (
+ set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% %SPARK_EXECUTOR_OPTS%
+ if not "x%SPARK_EXECUTOR_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_EXECUTOR_MEMORY%
+
+rem All drivers use SPARK_JAVA_OPTS + SPARK_DRIVER_MEMORY. The repl also uses SPARK_REPL_OPTS.
+) else if "%1"=="org.apache.spark.repl.Main" (
+ set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS% %SPARK_REPL_OPTS%
+ if not "x%SPARK_DRIVER_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DRIVER_MEMORY%
+) else (
+ set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS%
+ if not "x%SPARK_DRIVER_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DRIVER_MEMORY%
+)
rem Set JAVA_OPTS to be able to load native libraries and to set heap size
-set JAVA_OPTS=%OUR_JAVA_OPTS% -Djava.library.path=%SPARK_LIBRARY_PATH% -Xms%SPARK_MEM% -Xmx%SPARK_MEM%
+set JAVA_OPTS=%OUR_JAVA_OPTS% -Djava.library.path=%SPARK_LIBRARY_PATH% -Xms%OUR_JAVA_MEM% -Xmx%OUR_JAVA_MEM%
rem Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in ExecutorRunner.scala!
rem Test whether the user has built Spark
diff --git a/bin/spark-shell b/bin/spark-shell
index 2bff06cf70051..7d3fe3aca7f1d 100755
--- a/bin/spark-shell
+++ b/bin/spark-shell
@@ -45,13 +45,11 @@ if [ "$1" = "--help" ] || [ "$1" = "-h" ]; then
exit
fi
-SPARK_SHELL_OPTS=""
-
for o in "$@"; do
if [ "$1" = "-c" -o "$1" = "--cores" ]; then
shift
if [[ "$1" =~ $CORE_PATTERN ]]; then
- SPARK_SHELL_OPTS="$SPARK_SHELL_OPTS -Dspark.cores.max=$1"
+ SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Dspark.cores.max=$1"
shift
else
echo "ERROR: wrong format for -c/--cores"
@@ -61,7 +59,7 @@ for o in "$@"; do
if [ "$1" = "-em" -o "$1" = "--execmem" ]; then
shift
if [[ $1 =~ $MEM_PATTERN ]]; then
- SPARK_SHELL_OPTS="$SPARK_SHELL_OPTS -Dspark.executor.memory=$1"
+ SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Dspark.executor.memory=$1"
shift
else
echo "ERROR: wrong format for --execmem/-em"
@@ -71,7 +69,7 @@ for o in "$@"; do
if [ "$1" = "-dm" -o "$1" = "--drivermem" ]; then
shift
if [[ $1 =~ $MEM_PATTERN ]]; then
- export SPARK_MEM=$1
+ export SPARK_DRIVER_MEMORY=$1
shift
else
echo "ERROR: wrong format for --drivermem/-dm"
@@ -125,16 +123,18 @@ if [[ ! $? ]]; then
fi
if $cygwin; then
- # Workaround for issue involving JLine and Cygwin
- # (see http://sourceforge.net/p/jline/bugs/40/).
- # If you're using the Mintty terminal emulator in Cygwin, may need to set the
- # "Backspace sends ^H" setting in "Keys" section of the Mintty options
- # (see https://github.com/sbt/sbt/issues/562).
- stty -icanon min 1 -echo > /dev/null 2>&1
- $FWDIR/bin/spark-class -Djline.terminal=unix $SPARK_SHELL_OPTS org.apache.spark.repl.Main "$@"
- stty icanon echo > /dev/null 2>&1
+ # Workaround for issue involving JLine and Cygwin
+ # (see http://sourceforge.net/p/jline/bugs/40/).
+ # If you're using the Mintty terminal emulator in Cygwin, may need to set the
+ # "Backspace sends ^H" setting in "Keys" section of the Mintty options
+ # (see https://github.com/sbt/sbt/issues/562).
+ stty -icanon min 1 -echo > /dev/null 2>&1
+ export SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Djline.terminal=unix"
+ $FWDIR/bin/spark-class org.apache.spark.repl.Main "$@"
+ stty icanon echo > /dev/null 2>&1
else
- $FWDIR/bin/spark-class $SPARK_SHELL_OPTS org.apache.spark.repl.Main "$@"
+ export SPARK_REPL_OPTS
+ $FWDIR/bin/spark-class org.apache.spark.repl.Main "$@"
fi
# record the exit status lest it be overwritten:
diff --git a/core/pom.xml b/core/pom.xml
index 5576b0c3b4795..a6f478b09bda0 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -17,242 +17,264 @@
-->
- 4.0.0
-
- org.apache.spark
- spark-parent
- 1.0.0-incubating-SNAPSHOT
- ../pom.xml
-
-
+ 4.0.0
+ org.apache.spark
- spark-core_2.10
- jar
- Spark Project Core
- http://spark.incubator.apache.org/
+ spark-parent
+ 1.0.0-SNAPSHOT
+ ../pom.xml
+
-
-
- org.apache.hadoop
- hadoop-client
-
-
- net.java.dev.jets3t
- jets3t
-
-
- commons-logging
- commons-logging
-
-
-
-
- org.apache.avro
- avro
-
-
- org.apache.avro
- avro-ipc
-
-
- org.apache.zookeeper
- zookeeper
-
-
- org.eclipse.jetty
- jetty-server
-
-
- com.google.guava
- guava
-
-
- com.google.code.findbugs
- jsr305
-
-
- org.slf4j
- slf4j-api
-
-
- org.slf4j
- jul-to-slf4j
-
-
- org.slf4j
- jcl-over-slf4j
-
-
- log4j
- log4j
-
-
- org.slf4j
- slf4j-log4j12
-
-
- com.ning
- compress-lzf
-
-
- org.xerial.snappy
- snappy-java
-
-
- org.ow2.asm
- asm
-
-
- com.twitter
- chill_${scala.binary.version}
- 0.3.1
-
-
- com.twitter
- chill-java
- 0.3.1
-
-
- ${akka.group}
- akka-remote_${scala.binary.version}
-
-
- ${akka.group}
- akka-slf4j_${scala.binary.version}
-
-
- ${akka.group}
- akka-testkit_${scala.binary.version}
- test
-
-
- org.scala-lang
- scala-library
-
-
- net.liftweb
- lift-json_${scala.binary.version}
-
-
- it.unimi.dsi
- fastutil
-
-
- colt
- colt
-
-
- org.apache.mesos
- mesos
-
-
- io.netty
- netty-all
-
-
- com.clearspring.analytics
- stream
-
-
- com.codahale.metrics
- metrics-core
-
-
- com.codahale.metrics
- metrics-jvm
-
-
- com.codahale.metrics
- metrics-json
-
-
- com.codahale.metrics
- metrics-ganglia
-
-
- com.codahale.metrics
- metrics-graphite
-
-
- org.apache.derby
- derby
- test
-
-
- commons-io
- commons-io
- test
-
-
- org.scalatest
- scalatest_${scala.binary.version}
- test
-
-
- org.mockito
- mockito-all
- test
-
-
- org.scalacheck
- scalacheck_${scala.binary.version}
- test
-
-
- org.easymock
- easymock
- test
-
-
- com.novocode
- junit-interface
- test
-
-
-
- target/scala-${scala.binary.version}/classes
- target/scala-${scala.binary.version}/test-classes
-
-
- org.apache.maven.plugins
- maven-antrun-plugin
-
-
- test
-
- run
-
-
- true
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- org.scalatest
- scalatest-maven-plugin
-
-
- ${basedir}/..
- 1
- ${spark.classpath}
-
-
-
-
-
+ org.apache.spark
+ spark-core_2.10
+ jar
+ Spark Project Core
+ http://spark.apache.org/
+
+
+
+ yarn-alpha
+
+
+ org.apache.avro
+ avro
+
+
+
+
+
+
+
+ org.apache.hadoop
+ hadoop-client
+
+
+ net.java.dev.jets3t
+ jets3t
+
+
+ commons-logging
+ commons-logging
+
+
+
+
+ org.apache.curator
+ curator-recipes
+
+
+ org.eclipse.jetty
+ jetty-plus
+
+
+ org.eclipse.jetty
+ jetty-security
+
+
+ org.eclipse.jetty
+ jetty-util
+
+
+ org.eclipse.jetty
+ jetty-server
+
+
+ com.google.guava
+ guava
+
+
+ com.google.code.findbugs
+ jsr305
+
+
+ org.slf4j
+ slf4j-api
+
+
+ org.slf4j
+ jul-to-slf4j
+
+
+ org.slf4j
+ jcl-over-slf4j
+
+
+ log4j
+ log4j
+
+
+ org.slf4j
+ slf4j-log4j12
+
+
+ com.ning
+ compress-lzf
+
+
+ org.xerial.snappy
+ snappy-java
+
+
+ com.twitter
+ chill_${scala.binary.version}
+ 0.3.1
+
+
+ com.twitter
+ chill-java
+ 0.3.1
+
+
+ commons-net
+ commons-net
+
+
+ ${akka.group}
+ akka-remote_${scala.binary.version}
+
+
+ ${akka.group}
+ akka-slf4j_${scala.binary.version}
+
+
+ ${akka.group}
+ akka-testkit_${scala.binary.version}
+ test
+
+
+ org.scala-lang
+ scala-library
+
+
+ org.json4s
+ json4s-jackson_${scala.binary.version}
+ 3.2.6
+
+
+
+ org.scala-lang
+ scalap
+
+
+
+
+ it.unimi.dsi
+ fastutil
+
+
+ colt
+ colt
+
+
+ org.apache.mesos
+ mesos
+
+
+ io.netty
+ netty-all
+
+
+ com.clearspring.analytics
+ stream
+
+
+ com.codahale.metrics
+ metrics-core
+
+
+ com.codahale.metrics
+ metrics-jvm
+
+
+ com.codahale.metrics
+ metrics-json
+
+
+ com.codahale.metrics
+ metrics-graphite
+
+
+ org.apache.derby
+ derby
+ test
+
+
+ commons-io
+ commons-io
+ test
+
+
+ org.scalatest
+ scalatest_${scala.binary.version}
+ test
+
+
+ org.mockito
+ mockito-all
+ test
+
+
+ org.scalacheck
+ scalacheck_${scala.binary.version}
+ test
+
+
+ org.easymock
+ easymock
+ test
+
+
+ com.novocode
+ junit-interface
+ test
+
+
+
+ target/scala-${scala.binary.version}/classes
+ target/scala-${scala.binary.version}/test-classes
+
+
+ org.apache.maven.plugins
+ maven-antrun-plugin
+
+
+ test
+
+ run
+
+
+ true
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ org.scalatest
+ scalatest-maven-plugin
+
+
+ ${basedir}/..
+ 1
+ ${spark.classpath}
+
+
+
+
+
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.scala b/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
similarity index 69%
rename from core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.scala
rename to core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
index 7500a8943634b..57fd0a7a80494 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.scala
+++ b/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
@@ -15,16 +15,13 @@
* limitations under the License.
*/
-package org.apache.spark.api.java.function
+package org.apache.spark.api.java.function;
-import java.lang.{Double => JDouble, Iterable => JIterable}
+import java.io.Serializable;
/**
* A function that returns zero or more records of type Double from each input record.
*/
-// DoubleFlatMapFunction does not extend FlatMapFunction because flatMap is
-// overloaded for both FlatMapFunction and DoubleFlatMapFunction.
-abstract class DoubleFlatMapFunction[T] extends WrappedFunction1[T, JIterable[JDouble]]
- with Serializable {
- // Intentionally left blank
+public interface DoubleFlatMapFunction<T> extends Serializable {
+ public Iterable<Double> call(T t) throws Exception;
}
diff --git a/project/project/SparkPluginBuild.scala b/core/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java
similarity index 69%
rename from project/project/SparkPluginBuild.scala
rename to core/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java
index a88a5e14539ec..150144e0e418c 100644
--- a/project/project/SparkPluginBuild.scala
+++ b/core/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java
@@ -15,12 +15,13 @@
* limitations under the License.
*/
-import sbt._
+package org.apache.spark.api.java.function;
-object SparkPluginDef extends Build {
- lazy val root = Project("plugins", file(".")) dependsOn(junitXmlListener)
- /* This is not published in a Maven repository, so we get it from GitHub directly */
- lazy val junitXmlListener = uri(
- "https://github.com/chenkelmann/junit_xml_listener.git#3f8029fbfda54dc7a68b1afd2f885935e1090016"
- )
+import java.io.Serializable;
+
+/**
+ * A function that returns Doubles, and can be used to construct DoubleRDDs.
+ */
+public interface DoubleFunction<T> extends Serializable {
+ public double call(T t) throws Exception;
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java
similarity index 79%
rename from core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
rename to core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java
index bdb01f7670356..fa75842047c6a 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
+++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java
@@ -15,13 +15,13 @@
* limitations under the License.
*/
-package org.apache.spark.api.java.function
+package org.apache.spark.api.java.function;
-import scala.reflect.ClassTag
+import java.io.Serializable;
/**
* A function that returns zero or more output records from each input record.
*/
-abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] {
- def elementType(): ClassTag[R] = ClassTag.Any.asInstanceOf[ClassTag[R]]
-}
+public interface FlatMapFunction<T, R> extends Serializable {
+ public Iterable<R> call(T t) throws Exception;
+}
\ No newline at end of file
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java
similarity index 78%
rename from core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
rename to core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java
index aae1349c5e17c..d1fdec072443d 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
+++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java
@@ -15,13 +15,13 @@
* limitations under the License.
*/
-package org.apache.spark.api.java.function
+package org.apache.spark.api.java.function;
-import scala.reflect.ClassTag
+import java.io.Serializable;
/**
* A function that takes two inputs and returns zero or more output records.
*/
-abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] {
- def elementType() : ClassTag[C] = ClassTag.Any.asInstanceOf[ClassTag[C]]
-}
+public interface FlatMapFunction2<T1, T2, R> extends Serializable {
+ public Iterable<R> call(T1 t1, T2 t2) throws Exception;
+}
\ No newline at end of file
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function.scala b/core/src/main/java/org/apache/spark/api/java/function/Function.java
similarity index 72%
rename from core/src/main/scala/org/apache/spark/api/java/function/Function.scala
rename to core/src/main/java/org/apache/spark/api/java/function/Function.java
index a5e1701f7718f..d00551bb0add6 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/Function.scala
+++ b/core/src/main/java/org/apache/spark/api/java/function/Function.java
@@ -15,17 +15,15 @@
* limitations under the License.
*/
-package org.apache.spark.api.java.function
+package org.apache.spark.api.java.function;
-import scala.reflect.ClassTag
-import org.apache.spark.api.java.JavaSparkContext
+import java.io.Serializable;
/**
- * Base class for functions whose return types do not create special RDDs. PairFunction and
+ * Base interface for functions whose return types do not create special RDDs. PairFunction and
* DoubleFunction are handled separately, to allow PairRDDs and DoubleRDDs to be constructed
* when mapping RDDs of other types.
*/
-abstract class Function[T, R] extends WrappedFunction1[T, R] with Serializable {
- def returnType(): ClassTag[R] = JavaSparkContext.fakeClassTag
+public interface Function<T1, R> extends Serializable {
+ public R call(T1 v1) throws Exception;
}
-
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function2.scala b/core/src/main/java/org/apache/spark/api/java/function/Function2.java
similarity index 76%
rename from core/src/main/scala/org/apache/spark/api/java/function/Function2.scala
rename to core/src/main/java/org/apache/spark/api/java/function/Function2.java
index fa3616cbcb4d2..793caaa61ac5a 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/Function2.scala
+++ b/core/src/main/java/org/apache/spark/api/java/function/Function2.java
@@ -15,15 +15,13 @@
* limitations under the License.
*/
-package org.apache.spark.api.java.function
+package org.apache.spark.api.java.function;
-import scala.reflect.ClassTag
-import org.apache.spark.api.java.JavaSparkContext
+import java.io.Serializable;
/**
* A two-argument function that takes arguments of type T1 and T2 and returns an R.
*/
-abstract class Function2[T1, T2, R] extends WrappedFunction2[T1, T2, R] with Serializable {
- def returnType(): ClassTag[R] = JavaSparkContext.fakeClassTag
+public interface Function2<T1, T2, R> extends Serializable {
+ public R call(T1 v1, T2 v2) throws Exception;
}
-
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function3.scala b/core/src/main/java/org/apache/spark/api/java/function/Function3.java
similarity index 75%
rename from core/src/main/scala/org/apache/spark/api/java/function/Function3.scala
rename to core/src/main/java/org/apache/spark/api/java/function/Function3.java
index 45152891e9272..b4151c3417df4 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/Function3.scala
+++ b/core/src/main/java/org/apache/spark/api/java/function/Function3.java
@@ -15,14 +15,13 @@
* limitations under the License.
*/
-package org.apache.spark.api.java.function
+package org.apache.spark.api.java.function;
-import org.apache.spark.api.java.JavaSparkContext
-import scala.reflect.ClassTag
+import java.io.Serializable;
/**
* A three-argument function that takes arguments of type T1, T2 and T3 and returns an R.
*/
-abstract class Function3[T1, T2, T3, R] extends WrappedFunction3[T1, T2, T3, R] with Serializable {
- def returnType(): ClassTag[R] = JavaSparkContext.fakeClassTag
+public interface Function3<T1, T2, T3, R> extends Serializable {
+ public R call(T1 v1, T2 v2, T3 v3) throws Exception;
}
diff --git a/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java
new file mode 100644
index 0000000000000..691ef2eceb1f6
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+
+import scala.Tuple2;
+
+/**
+ * A function that returns zero or more key-value pair records from each input record. The
+ * key-value pairs are represented as scala.Tuple2 objects.
+ */
+public interface PairFlatMapFunction<T, K, V> extends Serializable {
+ public Iterable<Tuple2<K, V>> call(T t) throws Exception;
+}
diff --git a/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java b/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java
new file mode 100644
index 0000000000000..abd9bcc07ac61
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+
+import scala.Tuple2;
+
+/**
+ * A function that returns key-value pairs (Tuple2), and can be used to construct PairRDDs.
+ */
+public interface PairFunction<T, K, V> extends Serializable {
+ public Tuple2<K, V> call(T t) throws Exception;
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetchTracker.scala b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java
similarity index 77%
rename from core/src/main/scala/org/apache/spark/storage/BlockFetchTracker.scala
rename to core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java
index 2e0b0e6eda765..2a10435b7523a 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockFetchTracker.scala
+++ b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java
@@ -15,13 +15,13 @@
* limitations under the License.
*/
-package org.apache.spark.storage
+package org.apache.spark.api.java.function;
-private[spark] trait BlockFetchTracker {
- def totalBlocks : Int
- def numLocalBlocks: Int
- def numRemoteBlocks: Int
- def remoteFetchTime : Long
- def fetchWaitTime: Long
- def remoteBytesRead : Long
+import java.io.Serializable;
+
+/**
+ * A function with no return value.
+ */
+public interface VoidFunction<T> extends Serializable {
+ public void call(T t) throws Exception;
}
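
The Java API function types above are now plain serializable single-method interfaces, so user code only has to provide call(). A minimal sketch of implementing two of them, written in Scala for consistency with the rest of the codebase (Java anonymous classes, or Java 8 lambdas, would be the more typical callers); the value names are illustrative only:

    import org.apache.spark.api.java.function.{FlatMapFunction, Function}

    // Generic parameters are Java boxed types, hence the explicit Int.box.
    val toLength = new Function[String, java.lang.Integer] {
      override def call(v1: String): java.lang.Integer = Int.box(v1.length)
    }

    // FlatMapFunction returns a java.lang.Iterable of output records.
    val toWords = new FlatMapFunction[String, String] {
      override def call(line: String): java.lang.Iterable[String] =
        java.util.Arrays.asList(line.split(" "): _*)
    }
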
diff --git a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
index 754b46a4c7df2..a67392441ed29 100644
--- a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
@@ -79,7 +79,6 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
val completionIter = CompletionIterator[T, Iterator[T]](itr, {
val shuffleMetrics = new ShuffleReadMetrics
shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
- shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index 1daabecf23292..872e892c04fe6 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -71,10 +71,30 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
val computedValues = rdd.computeOrReadCheckpoint(split, context)
// Persist the result, so long as the task is not running locally
if (context.runningLocally) { return computedValues }
- val elements = new ArrayBuffer[Any]
- elements ++= computedValues
- blockManager.put(key, elements, storageLevel, tellMaster = true)
- elements.iterator.asInstanceOf[Iterator[T]]
+ if (storageLevel.useDisk && !storageLevel.useMemory) {
+ // In the case that this RDD is to be persisted using DISK_ONLY
+ // the iterator will be passed directly to the blockManager (rather than
+ // caching it to an ArrayBuffer first), then the resulting block data iterator
+ // will be passed back to the user. If the iterator generates a lot of data,
+ // this means that it doesn't all have to be held in memory at one time.
+ // This could also apply to MEMORY_ONLY_SER storage, but we need to make sure
+ // blocks aren't dropped by the block store before enabling that.
+ blockManager.put(key, computedValues, storageLevel, tellMaster = true)
+ return blockManager.get(key) match {
+ case Some(values) =>
+ return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
+ case None =>
+ logInfo("Failure to store %s".format(key))
+ throw new Exception("Block manager failed to return persisted values")
+ }
+ } else {
+ // In this case the RDD is cached to an array buffer. This will save the results
+ // if we're dealing with a 'one-time' iterator
+ val elements = new ArrayBuffer[Any]
+ elements ++= computedValues
+ blockManager.put(key, elements, storageLevel, tellMaster = true)
+ return elements.iterator.asInstanceOf[Iterator[T]]
+ }
} finally {
loading.synchronized {
loading.remove(key)
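
The new branch above only kicks in for storage levels that use disk but not memory. A minimal sketch of the user-facing case it targets, assuming a running SparkContext named sc:

    import org.apache.spark.storage.StorageLevel

    // With DISK_ONLY the computed iterator is handed straight to the BlockManager
    // instead of being buffered in an ArrayBuffer first, so partitions larger than
    // the available heap can still be persisted.
    val big = sc.parallelize(1 to 10000000)
      .map(i => (i, i.toString))
      .persist(StorageLevel.DISK_ONLY)

    big.count()  // materializing the RDD streams each partition to disk
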
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index cc30105940d1a..448f87b81ef4a 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -18,6 +18,7 @@
package org.apache.spark
import org.apache.spark.rdd.RDD
+import org.apache.spark.serializer.Serializer
/**
* Base class for dependencies.
@@ -43,12 +44,13 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
* Represents a dependency on the output of a shuffle stage.
* @param rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
- * @param serializerClass class name of the serializer to use
+ * @param serializer [[Serializer]] to use. If set to null, the default serializer, as specified
+ * by `spark.serializer` config option, will be used.
*/
class ShuffleDependency[K, V](
@transient rdd: RDD[_ <: Product2[K, V]],
val partitioner: Partitioner,
- val serializerClass: String = null)
+ val serializer: Serializer = null)
extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
val shuffleId: Int = rdd.context.newShuffleId()
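
A ShuffleDependency now carries a Serializer instance rather than a serializer class name. A minimal sketch, assuming an existing pairRdd: RDD[(Int, String)] and a SparkConf named conf; leaving the argument as null keeps the default named by spark.serializer:

    import org.apache.spark.{HashPartitioner, ShuffleDependency}
    import org.apache.spark.serializer.KryoSerializer

    val dep = new ShuffleDependency[Int, String](
      pairRdd, new HashPartitioner(8), new KryoSerializer(conf))
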
diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
index d3264a4bb3c81..3d7692ea8a49e 100644
--- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
@@ -23,7 +23,7 @@ import com.google.common.io.Files
import org.apache.spark.util.Utils
-private[spark] class HttpFileServer extends Logging {
+private[spark] class HttpFileServer(securityManager: SecurityManager) extends Logging {
var baseDir : File = null
var fileDir : File = null
@@ -38,9 +38,10 @@ private[spark] class HttpFileServer extends Logging {
fileDir.mkdir()
jarDir.mkdir()
logInfo("HTTP File server directory is " + baseDir)
- httpServer = new HttpServer(baseDir)
+ httpServer = new HttpServer(baseDir, securityManager)
httpServer.start()
serverUri = httpServer.uri
+ logDebug("HTTP file server started at: " + serverUri)
}
def stop() {
diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala
index 759e68ee0cc61..cb5df25fa48df 100644
--- a/core/src/main/scala/org/apache/spark/HttpServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpServer.scala
@@ -19,15 +19,18 @@ package org.apache.spark
import java.io.File
+import org.eclipse.jetty.util.security.{Constraint, Password}
+import org.eclipse.jetty.security.authentication.DigestAuthenticator
+import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService, SecurityHandler}
+
import org.eclipse.jetty.server.Server
import org.eclipse.jetty.server.bio.SocketConnector
-import org.eclipse.jetty.server.handler.DefaultHandler
-import org.eclipse.jetty.server.handler.HandlerList
-import org.eclipse.jetty.server.handler.ResourceHandler
+import org.eclipse.jetty.server.handler.{DefaultHandler, HandlerList, ResourceHandler}
import org.eclipse.jetty.util.thread.QueuedThreadPool
import org.apache.spark.util.Utils
+
/**
* Exception type thrown by HttpServer when it is in the wrong state for an operation.
*/
@@ -38,7 +41,8 @@ private[spark] class ServerStateException(message: String) extends Exception(mes
* as well as classes created by the interpreter when the user types in code. This is just a wrapper
* around a Jetty server.
*/
-private[spark] class HttpServer(resourceBase: File) extends Logging {
+private[spark] class HttpServer(resourceBase: File, securityManager: SecurityManager)
+ extends Logging {
private var server: Server = null
private var port: Int = -1
@@ -59,14 +63,60 @@ private[spark] class HttpServer(resourceBase: File) extends Logging {
server.setThreadPool(threadPool)
val resHandler = new ResourceHandler
resHandler.setResourceBase(resourceBase.getAbsolutePath)
+
val handlerList = new HandlerList
handlerList.setHandlers(Array(resHandler, new DefaultHandler))
- server.setHandler(handlerList)
+
+ if (securityManager.isAuthenticationEnabled()) {
+ logDebug("HttpServer is using security")
+ val sh = setupSecurityHandler(securityManager)
+ // make sure we go through security handler to get resources
+ sh.setHandler(handlerList)
+ server.setHandler(sh)
+ } else {
+ logDebug("HttpServer is not using security")
+ server.setHandler(handlerList)
+ }
+
server.start()
port = server.getConnectors()(0).getLocalPort()
}
}
+ /**
+ * Set up Jetty to use the HashLoginService with a single user and our
+ * shared secret. Configure it to use DIGEST-MD5 authentication so that the password
+ * isn't passed in plaintext.
+ */
+ private def setupSecurityHandler(securityMgr: SecurityManager): ConstraintSecurityHandler = {
+ val constraint = new Constraint()
+ // use DIGEST-MD5 as the authentication mechanism
+ constraint.setName(Constraint.__DIGEST_AUTH)
+ constraint.setRoles(Array("user"))
+ constraint.setAuthenticate(true)
+ constraint.setDataConstraint(Constraint.DC_NONE)
+
+ val cm = new ConstraintMapping()
+ cm.setConstraint(constraint)
+ cm.setPathSpec("/*")
+ val sh = new ConstraintSecurityHandler()
+
+ // the hashLoginService lets us do a single user and
+ // secret right now. This could be changed to use the
+ // JAASLoginService for other options.
+ val hashLogin = new HashLoginService()
+
+ val userCred = new Password(securityMgr.getSecretKey())
+ if (userCred == null) {
+ throw new Exception("Error: secret key is null with authentication on")
+ }
+ hashLogin.putUser(securityMgr.getHttpUser(), userCred, Array("user"))
+ sh.setLoginService(hashLogin)
+ sh.setAuthenticator(new DigestAuthenticator());
+ sh.setConstraintMappings(Array(cm))
+ sh
+ }
+
def stop() {
if (server == null) {
throw new ServerStateException("Server is already stopped")
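
HttpServer is internal (private[spark]), but the shape of the new constructor can be sketched as follows; when spark.authenticate is on, the handlers are wrapped in the DIGEST-MD5 ConstraintSecurityHandler set up above. The file path and secret below are placeholders:

    import java.io.File
    import org.apache.spark.{HttpServer, SecurityManager, SparkConf}

    val conf = new SparkConf()
      .set("spark.authenticate", "true")
      .set("spark.authenticate.secret", "not-a-real-secret")

    // Spark-internal usage, shown only to illustrate the new
    // (resourceBase, securityManager) constructor signature.
    val server = new HttpServer(new File("/tmp/spark-http"), new SecurityManager(conf))
    server.start()
    println(server.uri)
    server.stop()
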
diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala
index b749e5414dab6..7423082e34f47 100644
--- a/core/src/main/scala/org/apache/spark/Logging.scala
+++ b/core/src/main/scala/org/apache/spark/Logging.scala
@@ -19,6 +19,7 @@ package org.apache.spark
import org.apache.log4j.{LogManager, PropertyConfigurator}
import org.slf4j.{Logger, LoggerFactory}
+import org.slf4j.impl.StaticLoggerBinder
/**
* Utility trait for classes that want to log data. Creates a SLF4J logger for the class and allows
@@ -101,9 +102,11 @@ trait Logging {
}
private def initializeLogging() {
- // If Log4j doesn't seem initialized, load a default properties file
+ // If Log4j is being used, but is not initialized, load a default properties file
+ val binder = StaticLoggerBinder.getSingleton
+ val usingLog4j = binder.getLoggerFactoryClassStr.endsWith("Log4jLoggerFactory")
val log4jInitialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
- if (!log4jInitialized) {
+ if (!log4jInitialized && usingLog4j) {
val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
val classLoader = this.getClass.getClassLoader
Option(classLoader.getResource(defaultLogProps)) match {
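
With this check the bundled log4j-defaults.properties is only loaded when SLF4J is actually bound to log4j and log4j has no appenders configured, leaving users of other backends (such as logback) alone. A minimal sketch of the trait in use; the class name is hypothetical:

    import org.apache.spark.Logging

    class IngestJob extends Logging {
      def run(): Unit = logInfo("log4j defaults are applied lazily, on the first log call")
    }
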
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 5968973132942..80cbf951cb70e 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -35,13 +35,28 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int)
extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
-private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster)
+private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf)
extends Actor with Logging {
+ val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
+
def receive = {
case GetMapOutputStatuses(shuffleId: Int) =>
val hostPort = sender.path.address.hostPort
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
- sender ! tracker.getSerializedMapOutputStatuses(shuffleId)
+ val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
+ val serializedSize = mapOutputStatuses.size
+ if (serializedSize > maxAkkaFrameSize) {
+ val msg = s"Map output statuses were $serializedSize bytes which " +
+ s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)."
+
+ /* For SPARK-1244 we'll opt for just logging an error and then throwing an exception.
+ * Note that on exception the actor will just restart. A bigger refactoring (SPARK-1239)
+ * will ultimately remove this entire code path. */
+ val exception = new SparkException(msg)
+ logError(msg, exception)
+ throw exception
+ }
+ sender ! mapOutputStatuses
case StopMapOutputTracker =>
logInfo("MapOutputTrackerActor stopped!")
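
The actor now fails fast with a descriptive SparkException instead of handing Akka a message it would silently drop. A minimal sketch of the corresponding tuning knob (the value is in MB and is converted to bytes by AkkaUtils.maxFrameSizeBytes):

    import org.apache.spark.SparkConf

    // Raise the Akka frame size if map output statuses for very wide shuffles
    // exceed the default of 10 MB.
    val conf = new SparkConf().set("spark.akka.frameSize", "64")
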
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
new file mode 100644
index 0000000000000..591978c1d3630
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.net.{Authenticator, PasswordAuthentication}
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.security.Credentials
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.spark.deploy.SparkHadoopUtil
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * Spark class responsible for security.
+ *
+ * In general this class should be instantiated by the SparkEnv and most components
+ * should access it from that. There are some cases where the SparkEnv hasn't been
+ * initialized yet and this class must be instantiated directly.
+ *
+ * Spark currently supports authentication via a shared secret.
+ * Authentication can be configured to be on via the 'spark.authenticate' configuration
+ * parameter. This parameter controls whether the Spark communication protocols do
+ * authentication using the shared secret. This authentication is a basic handshake to
+ * make sure both sides have the same shared secret and are allowed to communicate.
+ * If the shared secret is not identical they will not be allowed to communicate.
+ *
+ * The Spark UI can also be secured by using javax servlet filters. A user may want to
+ * secure the UI if it has data that other users should not be allowed to see. The javax
+ * servlet filter specified by the user can authenticate the user and then once the user
+ * is logged in, Spark can compare that user versus the view acls to make sure they are
+ * authorized to view the UI. The configs 'spark.ui.acls.enable' and 'spark.ui.view.acls'
+ * control the behavior of the acls. Note that the person who started the application
+ * always has view access to the UI.
+ *
+ * Spark does not currently support encryption after authentication.
+ *
+ * At this point spark has multiple communication protocols that need to be secured and
+ * different underlying mechanisms are used depending on the protocol:
+ *
+ * - Akka -> The only option here is to use the Akka Remote secure-cookie functionality.
+ * Akka remoting allows you to specify a secure cookie that will be exchanged
+ * and ensured to be identical in the connection handshake between the client
+ * and the server. If they are not identical then the client will be refused
+ * to connect to the server. There is no control of the underlying
+ * authentication mechanism so its not clear if the password is passed in
+ * plaintext or uses DIGEST-MD5 or some other mechanism.
+ * Akka also has an option to turn on SSL, this option is not currently supported
+ * but we could add a configuration option in the future.
+ *
+ * - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty
+ * for the HttpServer. Jetty supports multiple authentication mechanisms -
+ * Basic, Digest, Form, Spengo, etc. It also supports multiple different login
+ * services - Hash, JAAS, Spnego, JDBC, etc. Spark currently uses the HashLoginService
+ * to authenticate using DIGEST-MD5 via a single user and the shared secret.
+ * Since we are using DIGEST-MD5, the shared secret is not passed on the wire
+ * in plaintext.
+ * We currently do not support SSL (https), but Jetty can be configured to use it
+ * so we could add a configuration option for this in the future.
+ *
+ * The Spark HttpServer installs the HashLoginServer and configures it to DIGEST-MD5.
+ * Any clients must specify the user and password. There is a default
+ * Authenticator installed in the SecurityManager that defines how the authentication
+ * is done; in this case it gets the user name and password from the request.
+ *
+ * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously
+ * exchange messages. For this we use the Java SASL
+ * (Simple Authentication and Security Layer) API and again use DIGEST-MD5
+ * as the authentication mechanism. This means the shared secret is not passed
+ * over the wire in plaintext.
+ * Note that SASL is pluggable as to what mechanism it uses. We currently use
+ * DIGEST-MD5 but this could be changed to use Kerberos or other in the future.
+ * Spark currently supports "auth" for the quality of protection, which means
+ * the connection does not support integrity or privacy protection (encryption)
+ * after authentication. SASL also supports "auth-int" and "auth-conf", which
+ * Spark could support in the future to allow the user to specify the quality
+ * of protection they want. If we support those, the messages will also have to
+ * be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's.
+ *
+ * Since the ConnectionManager does asynchronous message passing, the SASL
+ * authentication is a bit more complex. A ConnectionManager can be both a client
+ * and a Server, so for a particular connection it has to determine what to do.
+ * A ConnectionId was added to be able to track connections and is used to
+ * match up incoming messages with connections waiting for authentication.
+ * If it's acting as a client and trying to send a message to another ConnectionManager,
+ * it blocks the thread calling sendMessage until the SASL negotiation has occurred.
+ * The ConnectionManager tracks all the sendingConnections using the ConnectionId
+ * and waits for the response from the server and does the handshake.
+ *
+ * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters
+ * can be used. Yarn requires a specific AmIpFilter be installed for security to work
+ * properly. For non-Yarn deployments, users can write a filter to go through a
+ * company's normal login service. If an authentication filter is in place then the
+ * SparkUI can be configured to check the logged in user against the list of users who
+ * have view acls to see if that user is authorized.
+ * The filters can also be used for many different purposes. For instance filters
+ * could be used for logging, encryption, or compression.
+ *
+ * The exact mechanism used to generate/distribute the shared secret is deployment-specific.
+ *
+ * For Yarn deployments, the secret is automatically generated using the Akka remote
+ * Crypt.generateSecureCookie() API. The secret is placed in the Hadoop UGI which gets passed
+ * around via the Hadoop RPC mechanism. Hadoop RPC can be configured to support different levels
+ * of protection. See the Hadoop documentation for more details. Each Spark application on Yarn
+ * gets a different shared secret. On Yarn, the Spark UI gets configured to use the Hadoop Yarn
+ * AmIpFilter which requires the user to go through the ResourceManager Proxy. That Proxy is there
+ * to reduce the possibility of web based attacks through YARN. Hadoop can be configured to use
+ * filters to do authentication. That authentication then happens via the ResourceManager Proxy
+ * and Spark will use that to do authorization against the view acls.
+ *
+ * For other Spark deployments, the shared secret must be specified via the
+ * spark.authenticate.secret config.
+ * All the nodes (Master and Workers) and the applications need to have the same shared secret.
+ * This again is not ideal as one user could potentially affect another user's application.
+ * This should be enhanced in the future to provide better protection.
+ * If the UI needs to be secured the user needs to install a javax servlet filter to do the
+ * authentication. Spark will then use that user to compare against the view acls to do
+ * authorization. If no filter is in place the user is generally null and no authorization
+ * can take place.
+ */
+
+private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
+
+ // key used to store the spark secret in the Hadoop UGI
+ private val sparkSecretLookupKey = "sparkCookie"
+
+ private val authOn = sparkConf.getBoolean("spark.authenticate", false)
+ private val uiAclsOn = sparkConf.getBoolean("spark.ui.acls.enable", false)
+
+ // always add the current user and SPARK_USER to the viewAcls
+ private val aclUsers = ArrayBuffer[String](System.getProperty("user.name", ""),
+ Option(System.getenv("SPARK_USER")).getOrElse(""))
+ aclUsers ++= sparkConf.get("spark.ui.view.acls", "").split(',')
+ private val viewAcls = aclUsers.map(_.trim()).filter(!_.isEmpty).toSet
+
+ private val secretKey = generateSecretKey()
+ logInfo("SecurityManager, is authentication enabled: " + authOn +
+ " are ui acls enabled: " + uiAclsOn + " users with view permissions: " + viewAcls.toString())
+
+ // Set our own authenticator to properly negotiate user/password for HTTP connections.
+ // This is needed by the HTTP client fetching from the HttpServer. Put here so its
+ // only set once.
+ if (authOn) {
+ Authenticator.setDefault(
+ new Authenticator() {
+ override def getPasswordAuthentication(): PasswordAuthentication = {
+ var passAuth: PasswordAuthentication = null
+ val userInfo = getRequestingURL().getUserInfo()
+ if (userInfo != null) {
+ val parts = userInfo.split(":", 2)
+ passAuth = new PasswordAuthentication(parts(0), parts(1).toCharArray())
+ }
+ return passAuth
+ }
+ }
+ )
+ }
+
+ /**
+ * Generates or looks up the secret key.
+ *
+ * The way the key is stored depends on the Spark deployment mode. Yarn
+ * uses the Hadoop UGI.
+ *
+ * For non-Yarn deployments, if the config variable is not set,
+ * we throw an exception.
+ */
+ private def generateSecretKey(): String = {
+ if (!isAuthenticationEnabled) return null
+ // first check to see if the secret is already set, else generate a new one if on yarn
+ val sCookie = if (SparkHadoopUtil.get.isYarnMode) {
+ val secretKey = SparkHadoopUtil.get.getSecretKeyFromUserCredentials(sparkSecretLookupKey)
+ if (secretKey != null) {
+ logDebug("in yarn mode, getting secret from credentials")
+ return new Text(secretKey).toString
+ } else {
+ logDebug("getSecretKey: yarn mode, secret key from credentials is null")
+ }
+ val cookie = akka.util.Crypt.generateSecureCookie
+ // if we generated the secret then we must be the first so let's set it so it
+ // gets used by everyone else
+ SparkHadoopUtil.get.addSecretKeyToUserCredentials(sparkSecretLookupKey, cookie)
+ logInfo("adding secret to credentials in yarn mode")
+ cookie
+ } else {
+ // user must have set spark.authenticate.secret config
+ sparkConf.getOption("spark.authenticate.secret") match {
+ case Some(value) => value
+ case None => throw new Exception("Error: a secret key must be specified via the " +
+ "spark.authenticate.secret config")
+ }
+ }
+ sCookie
+ }
+
+ /**
+   * Check to see if acls for the UI are enabled
+   * @return true if UI acls are enabled, otherwise false
+ */
+ def uiAclsEnabled(): Boolean = uiAclsOn
+
+ /**
+ * Checks the given user against the view acl list to see if they have
+   * authorization to view the UI. If the UI acls are disabled
+ * via spark.ui.acls.enable, all users have view access.
+ *
+   * @param user the user to check for authorization
+   * @return true if the user has permission, otherwise false
+ */
+ def checkUIViewPermissions(user: String): Boolean = {
+ if (uiAclsEnabled() && (user != null) && (!viewAcls.contains(user))) false else true
+ }
+
+ /**
+ * Check to see if authentication for the Spark communication protocols is enabled
+ * @return true if authentication is enabled, otherwise false
+ */
+ def isAuthenticationEnabled(): Boolean = authOn
+
+ /**
+ * Gets the user used for authenticating HTTP connections.
+ * For now use a single hardcoded user.
+ * @return the HTTP user as a String
+ */
+ def getHttpUser(): String = "sparkHttpUser"
+
+ /**
+ * Gets the user used for authenticating SASL connections.
+ * For now use a single hardcoded user.
+ * @return the SASL user as a String
+ */
+ def getSaslUser(): String = "sparkSaslUser"
+
+ /**
+ * Gets the secret key.
+ * @return the secret key as a String if authentication is enabled, otherwise returns null
+ */
+ def getSecretKey(): String = secretKey
+}
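A minimal sketch of how the settings described in the class comment could be wired together for a non-YARN deployment. The config keys and `SecurityManager` methods come from this file; the secret value and user names are placeholders, and `SecurityManager` is `private[spark]`, so a snippet like this would have to live inside the `org.apache.spark` package.

```scala
package org.apache.spark

object SecurityManagerExample {
  def main(args: Array[String]) {
    val conf = new SparkConf()
      .set("spark.authenticate", "true")             // enable authentication of connections
      .set("spark.authenticate.secret", "mySecret")  // shared secret (required off YARN)
      .set("spark.ui.acls.enable", "true")           // enforce view acls on the web UI
      .set("spark.ui.view.acls", "alice,bob")        // users allowed to view the UI

    val securityManager = new SecurityManager(conf)
    println("auth enabled:    " + securityManager.isAuthenticationEnabled())
    println("alice allowed:   " + securityManager.checkUIViewPermissions("alice"))
    println("mallory allowed: " + securityManager.checkUIViewPermissions("mallory"))
  }
}
```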
diff --git a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
index e8f756c408889..a4f69b6b22b2c 100644
--- a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
@@ -29,7 +29,7 @@ private[spark] abstract class ShuffleFetcher {
shuffleId: Int,
reduceId: Int,
context: TaskContext,
- serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T]
+ serializer: Serializer = SparkEnv.get.serializer): Iterator[T]
/** Stop the fetcher */
def stop() {}
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index a24f07e9a6e9a..852ed8fe1fb91 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -130,6 +130,8 @@ class SparkContext(
val isLocal = (master == "local" || master.startsWith("local["))
+ if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true")
+
// Create the Spark execution environment (cache, map output tracker, etc)
private[spark] val env = SparkEnv.create(
conf,
@@ -160,19 +162,20 @@ class SparkContext(
jars.foreach(addJar)
}
+ def warnSparkMem(value: String): String = {
+ logWarning("Using SPARK_MEM to set amount of memory to use per executor process is " +
+ "deprecated, please use spark.executor.memory instead.")
+ value
+ }
+
private[spark] val executorMemory = conf.getOption("spark.executor.memory")
- .orElse(Option(System.getenv("SPARK_MEM")))
+ .orElse(Option(System.getenv("SPARK_EXECUTOR_MEMORY")))
+ .orElse(Option(System.getenv("SPARK_MEM")).map(warnSparkMem))
.map(Utils.memoryStringToMb)
.getOrElse(512)
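A small sketch of the new precedence this chain implements: `spark.executor.memory` wins, then the `SPARK_EXECUTOR_MEMORY` environment variable, then the deprecated `SPARK_MEM` (which now only adds a warning), then the 512 MB default. The "2g" value is only an example.

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf().set("spark.executor.memory", "2g")
// or, from the environment: export SPARK_EXECUTOR_MEMORY=2g
// SPARK_MEM=2g is still honored, but logs a deprecation warning
```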
- if (!conf.contains("spark.executor.memory") && sys.env.contains("SPARK_MEM")) {
- logWarning("Using SPARK_MEM to set amount of memory to use per executor process is " +
- "deprecated, instead use spark.executor.memory")
- }
-
// Environment variables to pass to our executors
private[spark] val executorEnvs = HashMap[String, String]()
- // Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner
for (key <- Seq("SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS");
value <- Option(System.getenv(key))) {
executorEnvs(key) = value
@@ -183,8 +186,9 @@ class SparkContext(
value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} {
executorEnvs(envKey) = value
}
- // Since memory can be set with a system property too, use that
- executorEnvs("SPARK_MEM") = executorMemory + "m"
+ // The Mesos scheduler backend relies on this environment variable to set executor memory.
+ // TODO: Set this only in the Mesos scheduler.
+ executorEnvs("SPARK_EXECUTOR_MEMORY") = executorMemory + "m"
executorEnvs ++= conf.getExecutorEnv
// Set SPARK_USER for user who is running SparkContext.
@@ -240,6 +244,7 @@ class SparkContext(
localProperties.set(props)
}
+ @deprecated("Properties no longer need to be explicitly initialized.", "1.0.0")
def initLocalProperties() {
localProperties.set(new Properties())
}
@@ -308,7 +313,7 @@ class SparkContext(
private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this)
private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this)
- def initDriverMetrics() {
+ private def initDriverMetrics() {
SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource)
SparkEnv.get.metricsSystem.registerSource(blockManagerSource)
}
@@ -350,7 +355,7 @@ class SparkContext(
* using the older MapReduce API (`org.apache.hadoop.mapred`).
*
* @param conf JobConf for setting up the dataset
- * @param inputFormatClass Class of the [[InputFormat]]
+ * @param inputFormatClass Class of the InputFormat
* @param keyClass Class of the keys
* @param valueClass Class of the values
* @param minSplits Minimum number of Hadoop Splits to generate.
@@ -633,7 +638,7 @@ class SparkContext(
addedFiles(key) = System.currentTimeMillis
// Fetch the file locally in case a job is executed using DAGScheduler.runLocally().
- Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf)
+ Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf, env.securityManager)
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
}
@@ -735,8 +740,10 @@ class SparkContext(
key = uri.getScheme match {
// A JAR file which exists only on the driver node
case null | "file" =>
- if (SparkHadoopUtil.get.isYarnMode() && master == "yarn-standalone") {
- // In order for this to work in yarn standalone mode the user must specify the
+ // yarn-standalone is deprecated, but still supported
+ if (SparkHadoopUtil.get.isYarnMode() &&
+ (master == "yarn-standalone" || master == "yarn-cluster")) {
+ // In order for this to work in yarn-cluster mode the user must specify the
// --addjars option to the client to upload the file into the distributed cache
// of the AM to make it show up in the current working directory.
val fileName = new Path(uri.getPath).getName()
@@ -825,13 +832,12 @@ class SparkContext(
setLocalProperty("externalCallSite", null)
}
+ /**
+ * Capture the current user callsite and return a formatted version for printing. If the user
+ * has overridden the call site, this will return the user's version.
+ */
private[spark] def getCallSite(): String = {
- val callSite = getLocalProperty("externalCallSite")
- if (callSite == null) {
- Utils.formatSparkCallSite
- } else {
- callSite
- }
+ Option(getLocalProperty("externalCallSite")).getOrElse(Utils.formatCallSiteInfo())
}
/**
@@ -846,6 +852,9 @@ class SparkContext(
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
+ partitions.foreach{ p =>
+ require(p >= 0 && p < rdd.partitions.size, s"Invalid partition requested: $p")
+ }
val callSite = getCallSite
val cleanedFunc = clean(func)
logInfo("Starting job: " + callSite)
@@ -949,6 +958,9 @@ class SparkContext(
resultHandler: (Int, U) => Unit,
resultFunc: => R): SimpleFutureAction[R] =
{
+ partitions.foreach{ p =>
+ require(p >= 0 && p < rdd.partitions.size, s"Invalid partition requested: $p")
+ }
val cleanF = clean(processPartition)
val callSite = getCallSite
val waiter = dagScheduler.submitJob(
@@ -1024,7 +1036,7 @@ class SparkContext(
* The SparkContext object contains a number of implicit conversions and parameters for use with
* various Spark features.
*/
-object SparkContext {
+object SparkContext extends Logging {
private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description"
@@ -1242,7 +1254,11 @@ object SparkContext {
}
scheduler
- case "yarn-standalone" =>
+ case "yarn-standalone" | "yarn-cluster" =>
+ if (master == "yarn-standalone") {
+ logWarning(
+ "\"yarn-standalone\" is deprecated as of Spark 1.0. Use \"yarn-cluster\" instead.")
+ }
val scheduler = try {
val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler")
val cons = clazz.getConstructor(classOf[SparkContext])
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 7ac65828f670f..774cbd6441a48 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -28,7 +28,7 @@ import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.storage.{BlockManager, BlockManagerMaster, BlockManagerMasterActor}
import org.apache.spark.network.ConnectionManager
-import org.apache.spark.serializer.{Serializer, SerializerManager}
+import org.apache.spark.serializer.Serializer
import org.apache.spark.util.{AkkaUtils, Utils}
/**
@@ -41,7 +41,6 @@ import org.apache.spark.util.{AkkaUtils, Utils}
class SparkEnv private[spark] (
val executorId: String,
val actorSystem: ActorSystem,
- val serializerManager: SerializerManager,
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheManager: CacheManager,
@@ -53,7 +52,8 @@ class SparkEnv private[spark] (
val httpFileServer: HttpFileServer,
val sparkFilesDir: String,
val metricsSystem: MetricsSystem,
- val conf: SparkConf) extends Logging {
+ val conf: SparkConf,
+ val securityManager: SecurityManager) extends Logging {
// A mapping of thread ID to amount of memory used for shuffle in bytes
// All accesses should be manually synchronized
@@ -122,8 +122,9 @@ object SparkEnv extends Logging {
isDriver: Boolean,
isLocal: Boolean): SparkEnv = {
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port,
- conf = conf)
+ val securityManager = new SecurityManager(conf)
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port, conf = conf,
+ securityManager = securityManager)
// Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port),
// figure out which port number Akka actually bound to and set spark.driver.port to it.
@@ -137,17 +138,22 @@ object SparkEnv extends Logging {
// defaultClassName if the property is not set, and return it as a T
def instantiateClass[T](propertyName: String, defaultClassName: String): T = {
val name = conf.get(propertyName, defaultClassName)
- Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
+ val cls = Class.forName(name, true, classLoader)
+ // First try with the constructor that takes SparkConf. If we can't find one,
+ // use a no-arg constructor instead.
+ try {
+ cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
+ } catch {
+ case _: NoSuchMethodException =>
+ cls.getConstructor().newInstance().asInstanceOf[T]
+ }
}
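For illustration, a sketch of the two constructor shapes the lookup above accepts. Both classes are hypothetical; a real plug-in named through `spark.serializer`, `spark.closure.serializer` or `spark.shuffle.fetcher` would additionally have to extend the corresponding Spark trait.

```scala
import org.apache.spark.SparkConf

class ConfAwareComponent(conf: SparkConf) {
  // preferred: the component can read its own settings from the SparkConf
  val bufferSize = conf.getInt("spark.buffer.size", 65536)
}

class SimpleComponent {
  // fallback: no SparkConf constructor, so instantiateClass uses the no-arg constructor
}
```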
- val serializerManager = new SerializerManager
-
- val serializer = serializerManager.setDefault(
- conf.get("spark.serializer", "org.apache.spark.serializer.JavaSerializer"), conf)
+ val serializer = instantiateClass[Serializer](
+ "spark.serializer", "org.apache.spark.serializer.JavaSerializer")
- val closureSerializer = serializerManager.get(
- conf.get("spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer"),
- conf)
+ val closureSerializer = instantiateClass[Serializer](
+ "spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer")
def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
if (isDriver) {
@@ -167,12 +173,12 @@ object SparkEnv extends Logging {
val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
"BlockManagerMaster",
new BlockManagerMasterActor(isLocal, conf)), conf)
- val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
- serializer, conf)
+ val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
+ serializer, conf, securityManager)
val connectionManager = blockManager.connectionManager
- val broadcastManager = new BroadcastManager(isDriver, conf)
+ val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
val cacheManager = new CacheManager(blockManager)
@@ -185,19 +191,19 @@ object SparkEnv extends Logging {
}
mapOutputTracker.trackerActor = registerOrLookup(
"MapOutputTracker",
- new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]))
+ new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
- val httpFileServer = new HttpFileServer()
+ val httpFileServer = new HttpFileServer(securityManager)
httpFileServer.initialize()
conf.set("spark.fileserver.uri", httpFileServer.serverUri)
val metricsSystem = if (isDriver) {
- MetricsSystem.createMetricsSystem("driver", conf)
+ MetricsSystem.createMetricsSystem("driver", conf, securityManager)
} else {
- MetricsSystem.createMetricsSystem("executor", conf)
+ MetricsSystem.createMetricsSystem("executor", conf, securityManager)
}
metricsSystem.start()
@@ -219,7 +225,6 @@ object SparkEnv extends Logging {
new SparkEnv(
executorId,
actorSystem,
- serializerManager,
serializer,
closureSerializer,
cacheManager,
@@ -231,6 +236,7 @@ object SparkEnv extends Logging {
httpFileServer,
sparkFilesDir,
metricsSystem,
- conf)
+ conf,
+ securityManager)
}
}
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
new file mode 100644
index 0000000000000..a2a871cbd3c31
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io.IOException
+import javax.security.auth.callback.Callback
+import javax.security.auth.callback.CallbackHandler
+import javax.security.auth.callback.NameCallback
+import javax.security.auth.callback.PasswordCallback
+import javax.security.auth.callback.UnsupportedCallbackException
+import javax.security.sasl.RealmCallback
+import javax.security.sasl.RealmChoiceCallback
+import javax.security.sasl.Sasl
+import javax.security.sasl.SaslClient
+import javax.security.sasl.SaslException
+
+import scala.collection.JavaConversions.mapAsJavaMap
+
+/**
+ * Implements SASL Client logic for Spark
+ */
+private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logging {
+
+ /**
+   * Used to respond to the server's counterpart, SaslServer, with SASL tokens
+   * represented as byte arrays.
+ *
+ * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
+ * configurable in the future.
+ */
+ private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST),
+ null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
+ new SparkSaslClientCallbackHandler(securityMgr))
+
+ /**
+ * Used to initiate SASL handshake with server.
+ * @return response to challenge if needed
+ */
+ def firstToken(): Array[Byte] = {
+ synchronized {
+ val saslToken: Array[Byte] =
+ if (saslClient != null && saslClient.hasInitialResponse()) {
+ logDebug("has initial response")
+ saslClient.evaluateChallenge(new Array[Byte](0))
+ } else {
+ new Array[Byte](0)
+ }
+ saslToken
+ }
+ }
+
+ /**
+ * Determines whether the authentication exchange has completed.
+   * @return true if complete, otherwise false
+ */
+ def isComplete(): Boolean = {
+ synchronized {
+ if (saslClient != null) saslClient.isComplete() else false
+ }
+ }
+
+ /**
+ * Respond to server's SASL token.
+ * @param saslTokenMessage contains server's SASL token
+ * @return client's response SASL token
+ */
+ def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = {
+ synchronized {
+ if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0)
+ }
+ }
+
+ /**
+ * Disposes of any system resources or security-sensitive information the
+ * SaslClient might be using.
+ */
+ def dispose() {
+ synchronized {
+ if (saslClient != null) {
+ try {
+ saslClient.dispose()
+ } catch {
+ case e: SaslException => // ignored
+ } finally {
+ saslClient = null
+ }
+ }
+ }
+ }
+
+ /**
+ * Implementation of javax.security.auth.callback.CallbackHandler
+   * that works with shared secrets.
+ */
+ private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends
+ CallbackHandler {
+
+ private val userName: String =
+ SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes())
+ private val secretKey = securityMgr.getSecretKey()
+ private val userPassword: Array[Char] =
+ SparkSaslServer.encodePassword(if (secretKey != null) secretKey.getBytes() else "".getBytes())
+
+ /**
+     * Implementation used to respond to SASL requests from the server.
+ *
+ * @param callbacks objects that indicate what credential information the
+ * server's SaslServer requires from the client.
+ */
+ override def handle(callbacks: Array[Callback]) {
+ logDebug("in the sasl client callback handler")
+ callbacks foreach {
+ case nc: NameCallback => {
+ logDebug("handle: SASL client callback: setting username: " + userName)
+ nc.setName(userName)
+ }
+ case pc: PasswordCallback => {
+ logDebug("handle: SASL client callback: setting userPassword")
+ pc.setPassword(userPassword)
+ }
+ case rc: RealmCallback => {
+ logDebug("handle: SASL client callback: setting realm: " + rc.getDefaultText())
+ rc.setText(rc.getDefaultText())
+ }
+ case cb: RealmChoiceCallback => {}
+ case cb: Callback => throw
+ new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback")
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
new file mode 100644
index 0000000000000..11fcb2ae3a5c5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import javax.security.auth.callback.Callback
+import javax.security.auth.callback.CallbackHandler
+import javax.security.auth.callback.NameCallback
+import javax.security.auth.callback.PasswordCallback
+import javax.security.auth.callback.UnsupportedCallbackException
+import javax.security.sasl.AuthorizeCallback
+import javax.security.sasl.RealmCallback
+import javax.security.sasl.Sasl
+import javax.security.sasl.SaslException
+import javax.security.sasl.SaslServer
+import scala.collection.JavaConversions.mapAsJavaMap
+import org.apache.commons.net.util.Base64
+
+/**
+ * Encapsulates SASL server logic
+ */
+private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Logging {
+
+ /**
+ * Actual SASL work done by this object from javax.security.sasl.
+ */
+ private var saslServer: SaslServer = Sasl.createSaslServer(SparkSaslServer.DIGEST, null,
+ SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
+ new SparkSaslDigestCallbackHandler(securityMgr))
+
+ /**
+ * Determines whether the authentication exchange has completed.
+   * @return true if complete, otherwise false
+ */
+ def isComplete(): Boolean = {
+ synchronized {
+ if (saslServer != null) saslServer.isComplete() else false
+ }
+ }
+
+ /**
+   * Used to respond to client SASL tokens.
+   * @param token client's SASL token
+   * @return response to send back to the client.
+ */
+ def response(token: Array[Byte]): Array[Byte] = {
+ synchronized {
+ if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0)
+ }
+ }
+
+ /**
+ * Disposes of any system resources or security-sensitive information the
+ * SaslServer might be using.
+ */
+ def dispose() {
+ synchronized {
+ if (saslServer != null) {
+ try {
+ saslServer.dispose()
+ } catch {
+ case e: SaslException => // ignore
+ } finally {
+ saslServer = null
+ }
+ }
+ }
+ }
+
+ /**
+ * Implementation of javax.security.auth.callback.CallbackHandler
+ * for SASL DIGEST-MD5 mechanism
+ */
+ private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager)
+ extends CallbackHandler {
+
+ private val userName: String =
+ SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes())
+
+ override def handle(callbacks: Array[Callback]) {
+ logDebug("In the sasl server callback handler")
+ callbacks foreach {
+ case nc: NameCallback => {
+ logDebug("handle: SASL server callback: setting username")
+ nc.setName(userName)
+ }
+ case pc: PasswordCallback => {
+ logDebug("handle: SASL server callback: setting userPassword")
+ val password: Array[Char] =
+ SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes())
+ pc.setPassword(password)
+ }
+ case rc: RealmCallback => {
+ logDebug("handle: SASL server callback: setting realm: " + rc.getDefaultText())
+ rc.setText(rc.getDefaultText())
+ }
+ case ac: AuthorizeCallback => {
+ val authid = ac.getAuthenticationID()
+ val authzid = ac.getAuthorizationID()
+ if (authid.equals(authzid)) {
+ logDebug("set auth to true")
+ ac.setAuthorized(true)
+ } else {
+ logDebug("set auth to false")
+ ac.setAuthorized(false)
+ }
+ if (ac.isAuthorized()) {
+ logDebug("sasl server is authorized")
+ ac.setAuthorizedID(authzid)
+ }
+ }
+ case cb: Callback => throw
+ new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback")
+ }
+ }
+ }
+}
+
+private[spark] object SparkSaslServer {
+
+ /**
+ * This is passed as the server name when creating the sasl client/server.
+ * This could be changed to be configurable in the future.
+ */
+ val SASL_DEFAULT_REALM = "default"
+
+ /**
+ * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
+ * configurable in the future.
+ */
+ val DIGEST = "DIGEST-MD5"
+
+ /**
+ * The quality of protection is just "auth". This means that we are doing
+   * authentication only; we are not supporting integrity or privacy protection of the
+ * communication channel after authentication. This could be changed to be configurable
+ * in the future.
+ */
+  val SASL_PROPS = Map(Sasl.QOP -> "auth", Sasl.SERVER_AUTH -> "true")
+
+ /**
+ * Encode a byte[] identifier as a Base64-encoded string.
+ *
+ * @param identifier identifier to encode
+ * @return Base64-encoded string
+ */
+ def encodeIdentifier(identifier: Array[Byte]): String = {
+ new String(Base64.encodeBase64(identifier))
+ }
+
+ /**
+   * Encode a password as a Base64-encoded char array.
+   * @param password the password as a byte array.
+   * @return the password as a char array.
+ */
+ def encodePassword(password: Array[Byte]): Array[Char] = {
+ new String(Base64.encodeBase64(password)).toCharArray()
+ }
+}
+
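A rough local sketch of how the client and server halves fit together. In Spark the tokens are exchanged over the connection layer rather than in one loop, and both classes are `private[spark]`; `securityManager` is assumed to be a `SecurityManager` with authentication enabled.

```scala
val client = new SparkSaslClient(securityManager)
val server = new SparkSaslServer(securityManager)

var clientToken = client.firstToken()              // empty for DIGEST-MD5 (no initial response)
while (!server.isComplete()) {
  val serverToken = server.response(clientToken)   // challenge, then the final server token
  if (!client.isComplete()) {
    clientToken = client.saslResponse(serverToken)
  }
}
// At this point both sides have proven knowledge of the shared secret.
client.dispose()
server.dispose()
```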
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index cae983ed4c652..be53ca2968cfb 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -46,6 +46,7 @@ class TaskContext(
}
def executeOnCompleteCallbacks() {
- onCompleteCallbacks.foreach{_()}
+ // Process complete callbacks in the reverse order of registration
+ onCompleteCallbacks.reverse.foreach{_()}
}
}
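A small sketch of the new ordering guarantee, assuming a `TaskContext` instance named `context`: callbacks registered later run first, mirroring the usual teardown order for resources.

```scala
context.addOnCompleteCallback(() => println("registered first, runs last"))
context.addOnCompleteCallback(() => println("registered second, runs first"))
context.executeOnCompleteCallbacks()  // prints "registered second..." then "registered first..."
```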
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index 071044463d980..f816bb43a5b44 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -83,7 +83,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja
* Return a new RDD containing only the elements that satisfy a predicate.
*/
def filter(f: JFunction[JDouble, java.lang.Boolean]): JavaDoubleRDD =
- fromRDD(srdd.filter(x => f(x).booleanValue()))
+ fromRDD(srdd.filter(x => f.call(x).booleanValue()))
/**
* Return a new RDD that is reduced into `numPartitions` partitions.
@@ -140,6 +140,14 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja
*/
def union(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.union(other.srdd))
+ /**
+ * Return the intersection of this RDD and another one. The output will not contain any duplicate
+ * elements, even if the input RDDs did.
+ *
+ * Note that this method performs a shuffle internally.
+ */
+ def intersection(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.intersection(other.srdd))
+
// Double RDD functions
/** Add up the elements in this RDD. */
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 3f672900cb90f..9596dbaf75488 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -26,13 +26,13 @@ import com.google.common.base.Optional
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.{JobConf, OutputFormat}
-import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
+import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, Job}
import org.apache.spark.{HashPartitioner, Partitioner}
import org.apache.spark.Partitioner._
import org.apache.spark.SparkContext.rddToPairRDDFunctions
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
-import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2}
+import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, PairFunction}
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.rdd.{OrderedRDDFunctions, RDD}
import org.apache.spark.storage.StorageLevel
@@ -89,7 +89,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* Return a new RDD containing only the elements that satisfy a predicate.
*/
def filter(f: JFunction[(K, V), java.lang.Boolean]): JavaPairRDD[K, V] =
- new JavaPairRDD[K, V](rdd.filter(x => f(x).booleanValue()))
+ new JavaPairRDD[K, V](rdd.filter(x => f.call(x).booleanValue()))
/**
* Return a new RDD that is reduced into `numPartitions` partitions.
@@ -126,6 +126,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
def union(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.union(other.rdd))
+ /**
+ * Return the intersection of this RDD and another one. The output will not contain any duplicate
+ * elements, even if the input RDDs did.
+ *
+ * Note that this method performs a shuffle internally.
+ */
+ def intersection(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] =
+ new JavaPairRDD[K, V](rdd.intersection(other.rdd))
+
+
// first() has to be overridden here so that the generated method has the signature
// 'public scala.Tuple2 first()'; if the trait's definition is used,
// then the method has the signature 'public java.lang.Object first()',
@@ -165,9 +175,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* Simplified version of combineByKey that hash-partitions the output RDD.
*/
def combineByKey[C](createCombiner: JFunction[V, C],
- mergeValue: JFunction2[C, V, C],
- mergeCombiners: JFunction2[C, C, C],
- numPartitions: Int): JavaPairRDD[K, C] =
+ mergeValue: JFunction2[C, V, C],
+ mergeCombiners: JFunction2[C, C, C],
+ numPartitions: Int): JavaPairRDD[K, C] =
combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions))
/**
@@ -442,7 +452,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
*/
def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = {
import scala.collection.JavaConverters._
- def fn = (x: V) => f.apply(x).asScala
+ def fn = (x: V) => f.call(x).asScala
implicit val ctag: ClassTag[U] = fakeClassTag
fromRDD(rdd.flatMapValues(fn))
}
@@ -511,49 +521,57 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
/** Output the RDD to any Hadoop-supported file system. */
def saveAsHadoopFile[F <: OutputFormat[_, _]](
- path: String,
- keyClass: Class[_],
- valueClass: Class[_],
- outputFormatClass: Class[F],
- conf: JobConf) {
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[F],
+ conf: JobConf) {
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, conf)
}
/** Output the RDD to any Hadoop-supported file system. */
def saveAsHadoopFile[F <: OutputFormat[_, _]](
- path: String,
- keyClass: Class[_],
- valueClass: Class[_],
- outputFormatClass: Class[F]) {
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[F]) {
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass)
}
/** Output the RDD to any Hadoop-supported file system, compressing with the supplied codec. */
def saveAsHadoopFile[F <: OutputFormat[_, _]](
- path: String,
- keyClass: Class[_],
- valueClass: Class[_],
- outputFormatClass: Class[F],
- codec: Class[_ <: CompressionCodec]) {
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[F],
+ codec: Class[_ <: CompressionCodec]) {
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, codec)
}
/** Output the RDD to any Hadoop-supported file system. */
def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]](
- path: String,
- keyClass: Class[_],
- valueClass: Class[_],
- outputFormatClass: Class[F],
- conf: Configuration) {
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[F],
+ conf: Configuration) {
rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, conf)
}
+ /**
+ * Output the RDD to any Hadoop-supported storage system, using
+ * a Configuration object for that storage system.
+ */
+ def saveAsNewAPIHadoopDataset(conf: Configuration) {
+ rdd.saveAsNewAPIHadoopDataset(conf)
+ }
+
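A hypothetical usage sketch of the new `saveAsNewAPIHadoopDataset`; `pairs` is an assumed `JavaPairRDD[Text, Text]` and the output path is a placeholder. The `Configuration` is prepared through a Hadoop `Job` so that the output format, key/value classes and output directory travel with it.

```scala
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.Text
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}

val job = new Job()  // wraps a fresh Configuration
job.setOutputKeyClass(classOf[Text])
job.setOutputValueClass(classOf[Text])
job.setOutputFormatClass(classOf[TextOutputFormat[Text, Text]])
FileOutputFormat.setOutputPath(job, new Path("hdfs:///tmp/example-output"))

pairs.saveAsNewAPIHadoopDataset(job.getConfiguration)
```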
/** Output the RDD to any Hadoop-supported file system. */
def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]](
- path: String,
- keyClass: Class[_],
- valueClass: Class[_],
- outputFormatClass: Class[F]) {
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[F]) {
rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass)
}
@@ -700,6 +718,15 @@ object JavaPairRDD {
implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd
+ private[spark]
+ implicit def toScalaFunction2[T1, T2, R](fun: JFunction2[T1, T2, R]): Function2[T1, T2, R] = {
+ (x: T1, x1: T2) => fun.call(x, x1)
+ }
+
+ private[spark] implicit def toScalaFunction[T, R](fun: JFunction[T, R]): T => R = x => fun.call(x)
+
+ private[spark]
+ implicit def pairFunToScalaFun[A, B, C](x: PairFunction[A, B, C]): A => (B, C) = y => x.call(y)
/** Convert a JavaRDD of key-value pairs to JavaPairRDD. */
def fromJavaRDD[K, V](rdd: JavaRDD[(K, V)]): JavaPairRDD[K, V] = {
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
index 0055c98844ded..01d9357a2556d 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -70,7 +70,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
* Return a new RDD containing only the elements that satisfy a predicate.
*/
def filter(f: JFunction[T, java.lang.Boolean]): JavaRDD[T] =
- wrapRDD(rdd.filter((x => f(x).booleanValue())))
+ wrapRDD(rdd.filter((x => f.call(x).booleanValue())))
/**
* Return a new RDD that is reduced into `numPartitions` partitions.
@@ -106,6 +106,15 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
*/
def union(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.union(other.rdd))
+
+ /**
+ * Return the intersection of this RDD and another one. The output will not contain any duplicate
+ * elements, even if the input RDDs did.
+ *
+ * Note that this method performs a shuffle internally.
+ */
+ def intersection(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.intersection(other.rdd))
+
/**
* Return an RDD with the elements from `this` that are not in `other`.
*
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 24a9925dbd22c..05b89b985736d 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -19,7 +19,6 @@ package org.apache.spark.api.java
import java.util.{Comparator, List => JList}
-import scala.Tuple2
import scala.collection.JavaConversions._
import scala.reflect.ClassTag
@@ -67,14 +66,14 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Return a new RDD by applying a function to all elements of this RDD.
*/
def map[R](f: JFunction[T, R]): JavaRDD[R] =
- new JavaRDD(rdd.map(f)(f.returnType()))(f.returnType())
+ new JavaRDD(rdd.map(f)(fakeClassTag))(fakeClassTag)
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
*/
def mapPartitionsWithIndex[R: ClassTag](
- f: JFunction2[Int, java.util.Iterator[T], java.util.Iterator[R]],
+ f: JFunction2[java.lang.Integer, java.util.Iterator[T], java.util.Iterator[R]],
preservesPartitioning: Boolean = false): JavaRDD[R] =
new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))),
preservesPartitioning))
@@ -82,15 +81,16 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
/**
* Return a new RDD by applying a function to all elements of this RDD.
*/
- def map[R](f: DoubleFunction[T]): JavaDoubleRDD =
- new JavaDoubleRDD(rdd.map(x => f(x).doubleValue()))
+ def mapToDouble[R](f: DoubleFunction[T]): JavaDoubleRDD = {
+ new JavaDoubleRDD(rdd.map(x => f.call(x).doubleValue()))
+ }
/**
* Return a new RDD by applying a function to all elements of this RDD.
*/
- def map[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
- val ctag = implicitly[ClassTag[Tuple2[K2, V2]]]
- new JavaPairRDD(rdd.map(f)(ctag))(f.keyType(), f.valueType())
+ def mapToPair[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
+ def cm = implicitly[ClassTag[(K2, V2)]]
+ new JavaPairRDD(rdd.map[(K2, V2)](f)(cm))(fakeClassTag[K2], fakeClassTag[V2])
}
/**
@@ -99,17 +99,17 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
*/
def flatMap[U](f: FlatMapFunction[T, U]): JavaRDD[U] = {
import scala.collection.JavaConverters._
- def fn = (x: T) => f.apply(x).asScala
- JavaRDD.fromRDD(rdd.flatMap(fn)(f.elementType()))(f.elementType())
+ def fn = (x: T) => f.call(x).asScala
+ JavaRDD.fromRDD(rdd.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U])
}
/**
* Return a new RDD by first applying a function to all elements of this
* RDD, and then flattening the results.
*/
- def flatMap(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = {
+ def flatMapToDouble(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = {
import scala.collection.JavaConverters._
- def fn = (x: T) => f.apply(x).asScala
+ def fn = (x: T) => f.call(x).asScala
new JavaDoubleRDD(rdd.flatMap(fn).map((x: java.lang.Double) => x.doubleValue()))
}
@@ -117,19 +117,19 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Return a new RDD by first applying a function to all elements of this
* RDD, and then flattening the results.
*/
- def flatMap[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
+ def flatMapToPair[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
import scala.collection.JavaConverters._
- def fn = (x: T) => f.apply(x).asScala
- val ctag = implicitly[ClassTag[Tuple2[K2, V2]]]
- JavaPairRDD.fromRDD(rdd.flatMap(fn)(ctag))(f.keyType(), f.valueType())
+ def fn = (x: T) => f.call(x).asScala
+ def cm = implicitly[ClassTag[(K2, V2)]]
+ JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2])
}
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = {
- def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
- JavaRDD.fromRDD(rdd.mapPartitions(fn)(f.elementType()))(f.elementType())
+ def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator())
+ JavaRDD.fromRDD(rdd.mapPartitions(fn)(fakeClassTag[U]))(fakeClassTag[U])
}
/**
@@ -137,52 +137,53 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
*/
def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U],
preservesPartitioning: Boolean): JavaRDD[U] = {
- def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
- JavaRDD.fromRDD(rdd.mapPartitions(fn, preservesPartitioning)(f.elementType()))(f.elementType())
+ def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator())
+ JavaRDD.fromRDD(
+ rdd.mapPartitions(fn, preservesPartitioning)(fakeClassTag[U]))(fakeClassTag[U])
}
/**
- * Return a new RDD by applying a function to each partition of this RDD.
+ * Return a new RDD by applying a function to each partition of this RDD.
*/
- def mapPartitions(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = {
- def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
+ def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = {
+ def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator())
new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: java.lang.Double) => x.doubleValue()))
}
/**
- * Return a new RDD by applying a function to each partition of this RDD.
+ * Return a new RDD by applying a function to each partition of this RDD.
*/
- def mapPartitions[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]):
+ def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]):
JavaPairRDD[K2, V2] = {
- def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
- JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(f.keyType(), f.valueType())
+ def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator())
+ JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(fakeClassTag[K2], fakeClassTag[V2])
}
-
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
- def mapPartitions(f: DoubleFlatMapFunction[java.util.Iterator[T]],
- preservesPartitioning: Boolean): JavaDoubleRDD = {
- def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
+ def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]],
+ preservesPartitioning: Boolean): JavaDoubleRDD = {
+ def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator())
new JavaDoubleRDD(rdd.mapPartitions(fn, preservesPartitioning)
- .map((x: java.lang.Double) => x.doubleValue()))
+ .map(x => x.doubleValue()))
}
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
- def mapPartitions[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2],
+ def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2],
preservesPartitioning: Boolean): JavaPairRDD[K2, V2] = {
- def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
- JavaPairRDD.fromRDD(rdd.mapPartitions(fn, preservesPartitioning))(f.keyType(), f.valueType())
+ def fn = (x: Iterator[T]) => asScalaIterator(f.call(asJavaIterator(x)).iterator())
+ JavaPairRDD.fromRDD(
+ rdd.mapPartitions(fn, preservesPartitioning))(fakeClassTag[K2], fakeClassTag[V2])
}
/**
* Applies a function f to each partition of this RDD.
*/
def foreachPartition(f: VoidFunction[java.util.Iterator[T]]) {
- rdd.foreachPartition((x => f(asJavaIterator(x))))
+ rdd.foreachPartition((x => f.call(asJavaIterator(x))))
}
/**
@@ -205,7 +206,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def groupBy[K](f: JFunction[T, K]): JavaPairRDD[K, JList[T]] = {
implicit val ctagK: ClassTag[K] = fakeClassTag
implicit val ctagV: ClassTag[JList[T]] = fakeClassTag
- JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(f.returnType)))
+ JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(fakeClassTag)))
}
/**
@@ -215,7 +216,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def groupBy[K](f: JFunction[T, K], numPartitions: Int): JavaPairRDD[K, JList[T]] = {
implicit val ctagK: ClassTag[K] = fakeClassTag
implicit val ctagV: ClassTag[JList[T]] = fakeClassTag
- JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(f.returnType)))
+ JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(fakeClassTag[K])))
}
/**
@@ -255,9 +256,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
other: JavaRDDLike[U, _],
f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V]): JavaRDD[V] = {
def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator(
- f.apply(asJavaIterator(x), asJavaIterator(y)).iterator())
+ f.call(asJavaIterator(x), asJavaIterator(y)).iterator())
JavaRDD.fromRDD(
- rdd.zipPartitions(other.rdd)(fn)(other.classTag, f.elementType()))(f.elementType())
+ rdd.zipPartitions(other.rdd)(fn)(other.classTag, fakeClassTag[V]))(fakeClassTag[V])
}
// Actions (launch a job to return a value to the user program)
@@ -266,7 +267,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Applies a function f to all elements of this RDD.
*/
def foreach(f: VoidFunction[T]) {
- val cleanF = rdd.context.clean(f)
+ val cleanF = rdd.context.clean((x: T) => f.call(x))
rdd.foreach(cleanF)
}
@@ -281,7 +282,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
/**
* Return an array that contains all of the elements in this RDD.
+ * @deprecated As of Spark 1.0.0, toArray() is deprecated, use {@link #collect()} instead
*/
+ @Deprecated
def toArray(): JList[T] = collect()
/**
@@ -320,7 +323,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
*/
def aggregate[U](zeroValue: U)(seqOp: JFunction2[U, T, U],
combOp: JFunction2[U, U, U]): U =
- rdd.aggregate(zeroValue)(seqOp, combOp)(seqOp.returnType)
+ rdd.aggregate(zeroValue)(seqOp, combOp)(fakeClassTag[U])
/**
* Return the number of elements in the RDD.
@@ -475,6 +478,26 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
new java.util.ArrayList(arr)
}
+ /**
+ * Returns the maximum element from this RDD as defined by the specified
+ * Comparator[T].
+   * @param comp the comparator that defines ordering
+   * @return the maximum of the RDD
+   */
+ def max(comp: Comparator[T]): T = {
+ rdd.max()(Ordering.comparatorToOrdering(comp))
+ }
+
+ /**
+ * Returns the minimum element from this RDD as defined by the specified
+ * Comparator[T].
+   * @param comp the comparator that defines ordering
+   * @return the minimum of the RDD
+   */
+ def min(comp: Comparator[T]): T = {
+ rdd.min()(Ordering.comparatorToOrdering(comp))
+ }
+
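A hypothetical sketch of the new max/min overloads, assuming a `JavaRDD[java.lang.Integer]` named `rdd`. The comparator mixes in `Serializable` because the derived ordering is shipped with the job.

```scala
import java.util.Comparator

val byValue = new Comparator[java.lang.Integer] with java.io.Serializable {
  override def compare(a: java.lang.Integer, b: java.lang.Integer): Int = a.compareTo(b)
}
val largest = rdd.max(byValue)
val smallest = rdd.min(byValue)
```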
/**
* Returns the first K elements from this RDD using the
* natural ordering for T while maintain the order.
@@ -498,8 +521,4 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def name(): String = rdd.name
- /** Reset generator */
- def setGenerator(_generator: String) = {
- rdd.setGenerator(_generator)
- }
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index dc26b7f621fee..8e0eab56a3dcf 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -17,6 +17,7 @@
package org.apache.spark.api.java
+import java.util
import java.util.{Map => JMap}
import scala.collection.JavaConversions
@@ -92,6 +93,24 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
private[spark] val env = sc.env
+ def isLocal: java.lang.Boolean = sc.isLocal
+
+ def sparkUser: String = sc.sparkUser
+
+ def master: String = sc.master
+
+ def appName: String = sc.appName
+
+ def jars: util.List[String] = sc.jars
+
+ def startTime: java.lang.Long = sc.startTime
+
+ /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */
+ def defaultParallelism: java.lang.Integer = sc.defaultParallelism
+
+ /** Default min number of partitions for Hadoop RDDs when not given by user */
+ def defaultMinSplits: java.lang.Integer = sc.defaultMinSplits
+
/** Distribute a local Scala collection to form an RDD. */
def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = {
implicit val ctag: ClassTag[T] = fakeClassTag
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.scala
deleted file mode 100644
index 2cdf2e92c3daa..0000000000000
--- a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.scala
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.api.java.function
-
-import java.lang.{Double => JDouble}
-
-/**
- * A function that returns Doubles, and can be used to construct DoubleRDDs.
- */
-// DoubleFunction does not extend Function because some UDF functions, like map,
-// are overloaded for both Function and DoubleFunction.
-abstract class DoubleFunction[T] extends WrappedFunction1[T, JDouble] with Serializable {
- // Intentionally left blank
-}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.scala
deleted file mode 100644
index 8467bbb892ab0..0000000000000
--- a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.scala
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.api.java.function
-
-import java.lang.{Iterable => JIterable}
-import org.apache.spark.api.java.JavaSparkContext
-import scala.reflect.ClassTag
-
-/**
- * A function that returns zero or more key-value pair records from each input record. The
- * key-value pairs are represented as scala.Tuple2 objects.
- */
-// PairFlatMapFunction does not extend FlatMapFunction because flatMap is
-// overloaded for both FlatMapFunction and PairFlatMapFunction.
-abstract class PairFlatMapFunction[T, K, V] extends WrappedFunction1[T, JIterable[(K, V)]]
- with Serializable {
-
- def keyType(): ClassTag[K] = JavaSparkContext.fakeClassTag
-
- def valueType(): ClassTag[V] = JavaSparkContext.fakeClassTag
-}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction1.scala b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction1.scala
deleted file mode 100644
index cfe694f65d558..0000000000000
--- a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction1.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.api.java.function
-
-import scala.runtime.AbstractFunction1
-
-/**
- * Subclass of Function1 for ease of calling from Java. The main thing it does is re-expose the
- * apply() method as call() and declare that it can throw Exception (since AbstractFunction1.apply
- * isn't marked to allow that).
- */
-private[spark] abstract class WrappedFunction1[T, R] extends AbstractFunction1[T, R] {
- @throws(classOf[Exception])
- def call(t: T): R
-
- final def apply(t: T): R = call(t)
-}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction2.scala b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction2.scala
deleted file mode 100644
index eb9277c6fb4cb..0000000000000
--- a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction2.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.api.java.function
-
-import scala.runtime.AbstractFunction2
-
-/**
- * Subclass of Function2 for ease of calling from Java. The main thing it does is re-expose the
- * apply() method as call() and declare that it can throw Exception (since AbstractFunction2.apply
- * isn't marked to allow that).
- */
-private[spark] abstract class WrappedFunction2[T1, T2, R] extends AbstractFunction2[T1, T2, R] {
- @throws(classOf[Exception])
- def call(t1: T1, t2: T2): R
-
- final def apply(t1: T1, t2: T2): R = call(t1, t2)
-}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala
deleted file mode 100644
index d314dbdf1d980..0000000000000
--- a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.api.java.function
-
-import scala.runtime.AbstractFunction3
-
-/**
- * Subclass of Function3 for ease of calling from Java. The main thing it does is re-expose the
- * apply() method as call() and declare that it can throw Exception (since AbstractFunction3.apply
- * isn't marked to allow that).
- */
-private[spark] abstract class WrappedFunction3[T1, T2, T3, R]
- extends AbstractFunction3[T1, T2, T3, R] {
- @throws(classOf[Exception])
- def call(t1: T1, t2: T2, t3: T3): R
-
- final def apply(t1: T1, t2: T2, t3: T3): R = call(t1, t2, t3)
-}
-
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index ca0e702addddd..43631e0e3bd49 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -103,6 +103,14 @@ private[spark] class PythonRDD[T: ClassTag](
}
}.start()
+ /*
+ * Partial fix for SPARK-1019: Attempts to stop reading the input stream since
+ * other completion callbacks might invalidate the input. Because interruption
+   * is not synchronous, this still leaves a potential race where the interruption is
+ * processed only after the stream becomes invalid.
+ */
+ context.addOnCompleteCallback(() => context.interrupted = true)
+
// Return an iterator that read lines from the process's stdout
val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
val stdoutIterator = new Iterator[Array[Byte]] {
diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
index d113d4040594d..e3c3a12d16f2a 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -60,7 +60,8 @@ abstract class Broadcast[T](val id: Long) extends Serializable {
}
private[spark]
-class BroadcastManager(val _isDriver: Boolean, conf: SparkConf) extends Logging with Serializable {
+class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager)
+ extends Logging with Serializable {
private var initialized = false
private var broadcastFactory: BroadcastFactory = null
@@ -78,7 +79,7 @@ class BroadcastManager(val _isDriver: Boolean, conf: SparkConf) extends Logging
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
// Initialize appropriate BroadcastFactory and BroadcastObject
- broadcastFactory.initialize(isDriver, conf)
+ broadcastFactory.initialize(isDriver, conf, securityManager)
initialized = true
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
index 940e5ab805100..6beecaeced5be 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -16,6 +16,7 @@
*/
package org.apache.spark.broadcast
+import org.apache.spark.SecurityManager
import org.apache.spark.SparkConf
@@ -26,7 +27,7 @@ import org.apache.spark.SparkConf
* entire Spark job.
*/
trait BroadcastFactory {
- def initialize(isDriver: Boolean, conf: SparkConf): Unit
+ def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
def stop(): Unit
}
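// Hypothetical decorator (not part of the patch) showing what the widened BroadcastFactory
// contract looks like to an implementor: initialize() now also receives the SecurityManager.
// It simply forwards to HttpBroadcastFactory, which the patch updates to the same signature.
import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.broadcast.{Broadcast, BroadcastFactory, HttpBroadcastFactory}

class LoggingBroadcastFactory extends BroadcastFactory {
  private val delegate = new HttpBroadcastFactory

  def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit = {
    // The SecurityManager is now available here, e.g. to check whether auth is on.
    println("broadcast auth enabled: " + securityMgr.isAuthenticationEnabled())
    delegate.initialize(isDriver, conf, securityMgr)
  }

  def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] =
    delegate.newBroadcast(value, isLocal, id)

  def stop(): Unit = delegate.stop()
}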
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 20207c261320b..e8eb04bb10469 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -18,13 +18,13 @@
package org.apache.spark.broadcast
import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream}
-import java.net.URL
+import java.net.{URL, URLConnection, URI}
import java.util.concurrent.TimeUnit
import it.unimi.dsi.fastutil.io.FastBufferedInputStream
import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
-import org.apache.spark.{HttpServer, Logging, SparkConf, SparkEnv}
+import org.apache.spark.{SparkConf, HttpServer, Logging, SecurityManager, SparkEnv}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}
@@ -67,7 +67,9 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
* A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium.
*/
class HttpBroadcastFactory extends BroadcastFactory {
- def initialize(isDriver: Boolean, conf: SparkConf) { HttpBroadcast.initialize(isDriver, conf) }
+ def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
+ HttpBroadcast.initialize(isDriver, conf, securityMgr)
+ }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)
@@ -83,6 +85,7 @@ private object HttpBroadcast extends Logging {
private var bufferSize: Int = 65536
private var serverUri: String = null
private var server: HttpServer = null
+ private var securityManager: SecurityManager = null
// TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist
private val files = new TimeStampedHashSet[String]
@@ -92,11 +95,12 @@ private object HttpBroadcast extends Logging {
private var compressionCodec: CompressionCodec = null
- def initialize(isDriver: Boolean, conf: SparkConf) {
+ def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
synchronized {
if (!initialized) {
bufferSize = conf.getInt("spark.buffer.size", 65536)
compress = conf.getBoolean("spark.broadcast.compress", true)
+ securityManager = securityMgr
if (isDriver) {
createServer(conf)
conf.set("spark.httpBroadcast.uri", serverUri)
@@ -126,7 +130,7 @@ private object HttpBroadcast extends Logging {
private def createServer(conf: SparkConf) {
broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf))
- server = new HttpServer(broadcastDir)
+ server = new HttpServer(broadcastDir, securityManager)
server.start()
serverUri = server.uri
logInfo("Broadcast server started at " + serverUri)
@@ -149,11 +153,23 @@ private object HttpBroadcast extends Logging {
}
def read[T](id: Long): T = {
+ logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id)
val url = serverUri + "/" + BroadcastBlockId(id).name
+
+ var uc: URLConnection = null
+ if (securityManager.isAuthenticationEnabled()) {
+ logDebug("broadcast security enabled")
+ val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager)
+ uc = newuri.toURL().openConnection()
+ uc.setAllowUserInteraction(false)
+ } else {
+ logDebug("broadcast not using security")
+ uc = new URL(url).openConnection()
+ }
+
val in = {
- val httpConnection = new URL(url).openConnection()
- httpConnection.setReadTimeout(httpReadTimeout)
- val inputStream = httpConnection.getInputStream
+ uc.setReadTimeout(httpReadTimeout)
+      val inputStream = uc.getInputStream()
if (compress) {
compressionCodec.compressedInputStream(inputStream)
} else {
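// Hedged sketch of the read path above: choose between an authenticated and a plain
// URLConnection, then apply the same read timeout either way. The patch delegates the
// authenticated case to Utils.constructURIForAuthentication (details not shown in this
// hunk); the sketch substitutes plain HTTP Basic credentials purely to illustrate where
// credentials would be attached.
import java.io.InputStream
import java.net.{URL, URLConnection}
import java.nio.charset.StandardCharsets
import java.util.Base64

object BroadcastFetchSketch {
  def openBroadcastStream(url: String, authEnabled: Boolean, user: String, secret: String,
      readTimeoutMs: Int): InputStream = {
    val uc: URLConnection =
      if (authEnabled) {
        val conn = new URL(url).openConnection()
        val token = Base64.getEncoder.encodeToString(
          (user + ":" + secret).getBytes(StandardCharsets.UTF_8))
        conn.setRequestProperty("Authorization", "Basic " + token)
        conn.setAllowUserInteraction(false)
        conn
      } else {
        new URL(url).openConnection()
      }
    uc.setReadTimeout(readTimeoutMs)
    uc.getInputStream
  }
}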
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 22d783c8590c6..3cd71213769b7 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -241,7 +241,9 @@ private[spark] case class TorrentInfo(
*/
class TorrentBroadcastFactory extends BroadcastFactory {
- def initialize(isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.initialize(isDriver, conf) }
+ def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
+ TorrentBroadcast.initialize(isDriver, conf)
+ }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TorrentBroadcast[T](value_, isLocal, id)
diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala
index eb5676b51d836..d9e3035e1ab59 100644
--- a/core/src/main/scala/org/apache/spark/deploy/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -26,7 +26,7 @@ import akka.pattern.ask
import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent}
import org.apache.log4j.{Level, Logger}
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.{DriverState, Master}
import org.apache.spark.util.{AkkaUtils, Utils}
@@ -141,7 +141,7 @@ object Client {
// TODO: See if we can initialize akka so return messages are sent back using the same TCP
// flow. Else, this (sadly) requires the DriverClient be routable from the Master.
val (actorSystem, _) = AkkaUtils.createActorSystem(
- "driverClient", Utils.localHostName(), 0, false, conf)
+ "driverClient", Utils.localHostName(), 0, false, conf, new SecurityManager(conf))
actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf))
diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
index 190b331cfe7d8..f4eb1601be3e4 100644
--- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
@@ -27,22 +27,27 @@ import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.sys.process._
-import net.liftweb.json.JsonParser
+import org.json4s._
+import org.json4s.jackson.JsonMethods
-import org.apache.spark.{Logging, SparkContext}
-import org.apache.spark.deploy.master.RecoveryState
+import org.apache.spark.{Logging, SparkConf, SparkContext}
+import org.apache.spark.deploy.master.{RecoveryState, SparkCuratorUtil}
/**
* This suite tests the fault tolerance of the Spark standalone scheduler, mainly the Master.
* In order to mimic a real distributed cluster more closely, Docker is used.
* Execute using
- * ./spark-class org.apache.spark.deploy.FaultToleranceTest
+ * ./bin/spark-class org.apache.spark.deploy.FaultToleranceTest
*
- * Make sure that that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS:
+ * Make sure that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS
+ * *and* SPARK_JAVA_OPTS:
* - spark.deploy.recoveryMode=ZOOKEEPER
* - spark.deploy.zookeeper.url=172.17.42.1:2181
* Note that 172.17.42.1 is the default docker ip for the host and 2181 is the default ZK port.
*
+ * In case of failure, make sure to kill off prior docker containers before restarting:
+ * docker kill $(docker ps -q)
+ *
* Unfortunately, due to the Docker dependency this suite cannot be run automatically without a
* working installation of Docker. In addition to having Docker, the following are assumed:
* - Docker can run without sudo (see http://docs.docker.io/en/latest/use/basics/)
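// For reference, the ZooKeeper recovery settings named in the doc comment above, written as
// a SparkConf. In the actual test they must be supplied as JVM system properties through
// SPARK_DAEMON_JAVA_OPTS and SPARK_JAVA_OPTS; the URL is the doc comment's example value.
import org.apache.spark.SparkConf

object FaultToleranceConfSketch {
  val conf = new SparkConf()
    .set("spark.deploy.recoveryMode", "ZOOKEEPER")
    .set("spark.deploy.zookeeper.url", "172.17.42.1:2181")
}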
@@ -50,10 +55,16 @@ import org.apache.spark.deploy.master.RecoveryState
* docker/ directory. Run 'docker/spark-test/build' to generate these.
*/
private[spark] object FaultToleranceTest extends App with Logging {
+
+ val conf = new SparkConf()
+ val ZK_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark")
+
val masters = ListBuffer[TestMasterInfo]()
val workers = ListBuffer[TestWorkerInfo]()
var sc: SparkContext = _
+ val zk = SparkCuratorUtil.newClient(conf)
+
var numPassed = 0
var numFailed = 0
@@ -71,6 +82,10 @@ private[spark] object FaultToleranceTest extends App with Logging {
sc = null
}
terminateCluster()
+
+ // Clear ZK directories in between tests (for speed purposes)
+ SparkCuratorUtil.deleteRecursive(zk, ZK_DIR + "/spark_leader")
+ SparkCuratorUtil.deleteRecursive(zk, ZK_DIR + "/master_status")
}
test("sanity-basic") {
@@ -167,26 +182,34 @@ private[spark] object FaultToleranceTest extends App with Logging {
try {
fn
numPassed += 1
+ logInfo("==============================================")
logInfo("Passed: " + name)
+ logInfo("==============================================")
} catch {
case e: Exception =>
numFailed += 1
+ logInfo("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
logError("FAILED: " + name, e)
+ logInfo("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
+ sys.exit(1)
}
afterEach()
}
def addMasters(num: Int) {
+ logInfo(s">>>>> ADD MASTERS $num <<<<<")
(1 to num).foreach { _ => masters += SparkDocker.startMaster(dockerMountDir) }
}
def addWorkers(num: Int) {
+ logInfo(s">>>>> ADD WORKERS $num <<<<<")
val masterUrls = getMasterUrls(masters)
(1 to num).foreach { _ => workers += SparkDocker.startWorker(dockerMountDir, masterUrls) }
}
/** Creates a SparkContext, which constructs a Client to interact with our cluster. */
def createClient() = {
+ logInfo(">>>>> CREATE CLIENT <<<<<")
if (sc != null) { sc.stop() }
// Counter-hack: Because of a hack in SparkEnv#create() that changes this
// property, we need to reset it.
@@ -205,6 +228,7 @@ private[spark] object FaultToleranceTest extends App with Logging {
}
def killLeader(): Unit = {
+ logInfo(">>>>> KILL LEADER <<<<<")
masters.foreach(_.readState())
val leader = getLeader
masters -= leader
@@ -214,6 +238,7 @@ private[spark] object FaultToleranceTest extends App with Logging {
def delay(secs: Duration = 5.seconds) = Thread.sleep(secs.toMillis)
def terminateCluster() {
+ logInfo(">>>>> TERMINATE CLUSTER <<<<<")
masters.foreach(_.kill())
workers.foreach(_.kill())
masters.clear()
@@ -244,6 +269,7 @@ private[spark] object FaultToleranceTest extends App with Logging {
* are all alive in a proper configuration (e.g., only one leader).
*/
def assertValidClusterState() = {
+ logInfo(">>>>> ASSERT VALID CLUSTER STATE <<<<<")
assertUsable()
var numAlive = 0
var numStandby = 0
@@ -311,7 +337,7 @@ private[spark] object FaultToleranceTest extends App with Logging {
private[spark] class TestMasterInfo(val ip: String, val dockerId: DockerId, val logFile: File)
extends Logging {
- implicit val formats = net.liftweb.json.DefaultFormats
+ implicit val formats = org.json4s.DefaultFormats
var state: RecoveryState.Value = _
var liveWorkerIPs: List[String] = _
var numLiveApps = 0
@@ -321,11 +347,15 @@ private[spark] class TestMasterInfo(val ip: String, val dockerId: DockerId, val
def readState() {
try {
val masterStream = new InputStreamReader(new URL("http://%s:8080/json".format(ip)).openStream)
- val json = JsonParser.parse(masterStream, closeAutomatically = true)
+ val json = JsonMethods.parse(masterStream)
val workers = json \ "workers"
val liveWorkers = workers.children.filter(w => (w \ "state").extract[String] == "ALIVE")
- liveWorkerIPs = liveWorkers.map(w => (w \ "host").extract[String])
+      // Extract the worker IP from "webuiaddress" (rather than "host") because the host name
+      // reported inside the containers is an opaque hash rather than the actual IP address.
+ liveWorkerIPs = liveWorkers.map {
+ w => (w \ "webuiaddress").extract[String].stripPrefix("http://").stripSuffix(":8081")
+ }
numLiveApps = (json \ "activeapps").children.size
@@ -349,7 +379,7 @@ private[spark] class TestMasterInfo(val ip: String, val dockerId: DockerId, val
private[spark] class TestWorkerInfo(val ip: String, val dockerId: DockerId, val logFile: File)
extends Logging {
- implicit val formats = net.liftweb.json.DefaultFormats
+ implicit val formats = org.json4s.DefaultFormats
logDebug("Created worker: " + this)
@@ -402,7 +432,7 @@ private[spark] object Docker extends Logging {
def makeRunCmd(imageTag: String, args: String = "", mountDir: String = ""): ProcessBuilder = {
val mountCmd = if (mountDir != "") { " -v " + mountDir } else ""
- val cmd = "docker run %s %s %s".format(mountCmd, imageTag, args)
+ val cmd = "docker run -privileged %s %s %s".format(mountCmd, imageTag, args)
logDebug("Run command: " + cmd)
cmd
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
index 318beb5db5214..cefb1ff97e83c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
@@ -17,7 +17,7 @@
package org.apache.spark.deploy
-import net.liftweb.json.JsonDSL._
+import org.json4s.JsonDSL._
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse}
import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index b479225b45ee9..d2d8d6d662d55 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -21,10 +21,13 @@ import java.security.PrivilegedExceptionAction
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.security.Credentials
import org.apache.hadoop.security.UserGroupInformation
import org.apache.spark.{SparkContext, SparkException}
+import scala.collection.JavaConversions._
+
/**
* Contains util methods to interact with Hadoop from Spark.
*/
@@ -33,15 +36,9 @@ class SparkHadoopUtil {
UserGroupInformation.setConfiguration(conf)
def runAsUser(user: String)(func: () => Unit) {
- // if we are already running as the user intended there is no reason to do the doAs. It
- // will actually break secure HDFS access as it doesn't fill in the credentials. Also if
- // the user is UNKNOWN then we shouldn't be creating a remote unknown user
- // (this is actually the path spark on yarn takes) since SPARK_USER is initialized only
- // in SparkContext.
- val currentUser = Option(System.getProperty("user.name")).
- getOrElse(SparkContext.SPARK_UNKNOWN_USER)
- if (user != SparkContext.SPARK_UNKNOWN_USER && currentUser != user) {
+ if (user != SparkContext.SPARK_UNKNOWN_USER) {
val ugi = UserGroupInformation.createRemoteUser(user)
+ transferCredentials(UserGroupInformation.getCurrentUser(), ugi)
ugi.doAs(new PrivilegedExceptionAction[Unit] {
def run: Unit = func()
})
@@ -50,6 +47,12 @@ class SparkHadoopUtil {
}
}
+ def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) {
+ for (token <- source.getTokens()) {
+ dest.addToken(token)
+ }
+ }
+
/**
* Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop
* subsystems.
@@ -63,6 +66,15 @@ class SparkHadoopUtil {
def addCredentials(conf: JobConf) {}
def isYarnMode(): Boolean = { false }
+
+ def getCurrentUserCredentials(): Credentials = { null }
+
+ def addCurrentUserCredentials(creds: Credentials) {}
+
+ def addSecretKeyToUserCredentials(key: String, secret: String) {}
+
+ def getSecretKeyFromUserCredentials(key: String): Array[Byte] = { null }
+
}
object SparkHadoopUtil {
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
index 1550c3eb4286b..63f166d401059 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
@@ -17,7 +17,7 @@
package org.apache.spark.deploy.client
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{SecurityManager, SparkConf, Logging}
import org.apache.spark.deploy.{ApplicationDescription, Command}
import org.apache.spark.util.{AkkaUtils, Utils}
@@ -45,8 +45,9 @@ private[spark] object TestClient {
def main(args: Array[String]) {
val url = args(0)
+ val conf = new SparkConf
val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0,
- conf = new SparkConf)
+ conf = conf, securityManager = new SecurityManager(conf))
val desc = new ApplicationDescription(
"TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()),
Some("dummy-spark-home"), "ignored")
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
index f25a1ad3bf92a..a730fe1f599af 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
@@ -30,6 +30,7 @@ import org.apache.spark.deploy.master.MasterMessages.ElectedLeader
* [[org.apache.spark.deploy.master.MasterMessages.RevokedLeadership RevokedLeadership]]
*/
private[spark] trait LeaderElectionAgent extends Actor {
+  // TODO: LeaderElectionAgent no longer needs to be an Actor; this should be refactored.
val masterActor: ActorRef
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 51794ce40cb45..b8dfa44102583 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -30,7 +30,7 @@ import akka.pattern.ask
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
import akka.serialization.SerializationExtension
-import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.DriverState.DriverState
@@ -39,7 +39,8 @@ import org.apache.spark.deploy.master.ui.MasterWebUI
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.util.{AkkaUtils, Utils}
-private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging {
+private[spark] class Master(host: String, port: Int, webUiPort: Int,
+ val securityMgr: SecurityManager) extends Actor with Logging {
import context.dispatcher // to use Akka's scheduler.schedule()
val conf = new SparkConf
@@ -70,8 +71,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
Utils.checkHost(host, "Expected hostname")
- val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf)
- val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf)
+ val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr)
+ val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf,
+ securityMgr)
val masterSource = new MasterSource(this)
val webUi = new MasterWebUI(this, webUiPort)
@@ -529,8 +531,15 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
val workerAddress = worker.actor.path.address
if (addressToWorker.contains(workerAddress)) {
- logInfo("Attempted to re-register worker at same address: " + workerAddress)
- return false
+ val oldWorker = addressToWorker(workerAddress)
+ if (oldWorker.state == WorkerState.UNKNOWN) {
+ // A worker registering from UNKNOWN implies that the worker was restarted during recovery.
+ // The old worker must thus be dead, so we will remove it and accept the new worker.
+ removeWorker(oldWorker)
+ } else {
+ logInfo("Attempted to re-register worker at same address: " + workerAddress)
+ return false
+ }
}
workers += worker
@@ -711,8 +720,11 @@ private[spark] object Master {
def startSystemAndActor(host: String, port: Int, webUiPort: Int, conf: SparkConf)
: (ActorSystem, Int, Int) =
{
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf)
- val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort), actorName)
+ val securityMgr = new SecurityManager(conf)
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf,
+ securityManager = securityMgr)
+ val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort,
+ securityMgr), actorName)
val timeout = AkkaUtils.askTimeout(conf)
val respFuture = actor.ask(RequestWebUIPort)(timeout)
val resp = Await.result(respFuture, timeout).asInstanceOf[WebUIPortResponse]
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
index 74a9f8cd824fb..db72d8ae9bdaf 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
@@ -28,10 +28,6 @@ private[master] object MasterMessages {
case object RevokedLeadership
- // Actor System to LeaderElectionAgent
-
- case object CheckLeader
-
// Actor System to Master
case object CheckForWorkerTimeOut
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala b/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala
new file mode 100644
index 0000000000000..4781a80d470e1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import scala.collection.JavaConversions._
+
+import org.apache.curator.framework.{CuratorFramework, CuratorFrameworkFactory}
+import org.apache.curator.retry.ExponentialBackoffRetry
+import org.apache.zookeeper.KeeperException
+
+import org.apache.spark.{Logging, SparkConf}
+
+object SparkCuratorUtil extends Logging {
+
+ val ZK_CONNECTION_TIMEOUT_MILLIS = 15000
+ val ZK_SESSION_TIMEOUT_MILLIS = 60000
+ val RETRY_WAIT_MILLIS = 5000
+ val MAX_RECONNECT_ATTEMPTS = 3
+
+ def newClient(conf: SparkConf): CuratorFramework = {
+ val ZK_URL = conf.get("spark.deploy.zookeeper.url")
+ val zk = CuratorFrameworkFactory.newClient(ZK_URL,
+ ZK_SESSION_TIMEOUT_MILLIS, ZK_CONNECTION_TIMEOUT_MILLIS,
+ new ExponentialBackoffRetry(RETRY_WAIT_MILLIS, MAX_RECONNECT_ATTEMPTS))
+ zk.start()
+ zk
+ }
+
+ def mkdir(zk: CuratorFramework, path: String) {
+ if (zk.checkExists().forPath(path) == null) {
+ try {
+ zk.create().creatingParentsIfNeeded().forPath(path)
+ } catch {
+ case nodeExist: KeeperException.NodeExistsException =>
+          // The znode already exists; nothing to do.
+ case e: Exception => throw e
+ }
+ }
+ }
+
+ def deleteRecursive(zk: CuratorFramework, path: String) {
+ if (zk.checkExists().forPath(path) != null) {
+ for (child <- zk.getChildren.forPath(path)) {
+ zk.delete().forPath(path + "/" + child)
+ }
+ zk.delete().forPath(path)
+ }
+ }
+}
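// Illustrative usage of the new SparkCuratorUtil helper, assuming a ZooKeeper ensemble is
// reachable at spark.deploy.zookeeper.url. The path below is made up for the example.
import org.apache.spark.SparkConf
import org.apache.spark.deploy.master.SparkCuratorUtil

object CuratorUtilExample extends App {
  val conf = new SparkConf().set("spark.deploy.zookeeper.url", "172.17.42.1:2181")
  val zk = SparkCuratorUtil.newClient(conf)
  try {
    SparkCuratorUtil.mkdir(zk, "/spark/example")           // no-op if the znode already exists
    SparkCuratorUtil.deleteRecursive(zk, "/spark/example")  // deletes children, then the node
  } finally {
    zk.close()
  }
}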
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala b/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala
deleted file mode 100644
index 57758055b19c0..0000000000000
--- a/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala
+++ /dev/null
@@ -1,205 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.deploy.master
-
-import scala.collection.JavaConversions._
-
-import org.apache.zookeeper._
-import org.apache.zookeeper.Watcher.Event.KeeperState
-import org.apache.zookeeper.data.Stat
-
-import org.apache.spark.{Logging, SparkConf}
-
-/**
- * Provides a Scala-side interface to the standard ZooKeeper client, with the addition of retry
- * logic. If the ZooKeeper session expires or otherwise dies, a new ZooKeeper session will be
- * created. If ZooKeeper remains down after several retries, the given
- * [[org.apache.spark.deploy.master.SparkZooKeeperWatcher SparkZooKeeperWatcher]] will be
- * informed via zkDown().
- *
- * Additionally, all commands sent to ZooKeeper will be retried until they either fail too many
- * times or a semantic exception is thrown (e.g., "node already exists").
- */
-private[spark] class SparkZooKeeperSession(zkWatcher: SparkZooKeeperWatcher,
- conf: SparkConf) extends Logging {
- val ZK_URL = conf.get("spark.deploy.zookeeper.url", "")
-
- val ZK_ACL = ZooDefs.Ids.OPEN_ACL_UNSAFE
- val ZK_TIMEOUT_MILLIS = 30000
- val RETRY_WAIT_MILLIS = 5000
- val ZK_CHECK_PERIOD_MILLIS = 10000
- val MAX_RECONNECT_ATTEMPTS = 3
-
- private var zk: ZooKeeper = _
-
- private val watcher = new ZooKeeperWatcher()
- private var reconnectAttempts = 0
- private var closed = false
-
- /** Connect to ZooKeeper to start the session. Must be called before anything else. */
- def connect() {
- connectToZooKeeper()
-
- new Thread() {
- override def run() = sessionMonitorThread()
- }.start()
- }
-
- def sessionMonitorThread(): Unit = {
- while (!closed) {
- Thread.sleep(ZK_CHECK_PERIOD_MILLIS)
- if (zk.getState != ZooKeeper.States.CONNECTED) {
- reconnectAttempts += 1
- val attemptsLeft = MAX_RECONNECT_ATTEMPTS - reconnectAttempts
- if (attemptsLeft <= 0) {
- logError("Could not connect to ZooKeeper: system failure")
- zkWatcher.zkDown()
- close()
- } else {
- logWarning("ZooKeeper connection failed, retrying " + attemptsLeft + " more times...")
- connectToZooKeeper()
- }
- }
- }
- }
-
- def close() {
- if (!closed && zk != null) { zk.close() }
- closed = true
- }
-
- private def connectToZooKeeper() {
- if (zk != null) zk.close()
- zk = new ZooKeeper(ZK_URL, ZK_TIMEOUT_MILLIS, watcher)
- }
-
- /**
- * Attempts to maintain a live ZooKeeper exception despite (very) transient failures.
- * Mainly useful for handling the natural ZooKeeper session expiration.
- */
- private class ZooKeeperWatcher extends Watcher {
- def process(event: WatchedEvent) {
- if (closed) { return }
-
- event.getState match {
- case KeeperState.SyncConnected =>
- reconnectAttempts = 0
- zkWatcher.zkSessionCreated()
- case KeeperState.Expired =>
- connectToZooKeeper()
- case KeeperState.Disconnected =>
- logWarning("ZooKeeper disconnected, will retry...")
- case s => // Do nothing
- }
- }
- }
-
- def create(path: String, bytes: Array[Byte], createMode: CreateMode): String = {
- retry {
- zk.create(path, bytes, ZK_ACL, createMode)
- }
- }
-
- def exists(path: String, watcher: Watcher = null): Stat = {
- retry {
- zk.exists(path, watcher)
- }
- }
-
- def getChildren(path: String, watcher: Watcher = null): List[String] = {
- retry {
- zk.getChildren(path, watcher).toList
- }
- }
-
- def getData(path: String): Array[Byte] = {
- retry {
- zk.getData(path, false, null)
- }
- }
-
- def delete(path: String, version: Int = -1): Unit = {
- retry {
- zk.delete(path, version)
- }
- }
-
- /**
- * Creates the given directory (non-recursively) if it doesn't exist.
- * All znodes are created in PERSISTENT mode with no data.
- */
- def mkdir(path: String) {
- if (exists(path) == null) {
- try {
- create(path, "".getBytes, CreateMode.PERSISTENT)
- } catch {
- case e: Exception =>
- // If the exception caused the directory not to be created, bubble it up,
- // otherwise ignore it.
- if (exists(path) == null) { throw e }
- }
- }
- }
-
- /**
- * Recursively creates all directories up to the given one.
- * All znodes are created in PERSISTENT mode with no data.
- */
- def mkdirRecursive(path: String) {
- var fullDir = ""
- for (dentry <- path.split("/").tail) {
- fullDir += "/" + dentry
- mkdir(fullDir)
- }
- }
-
- /**
- * Retries the given function up to 3 times. The assumption is that failure is transient,
- * UNLESS it is a semantic exception (i.e., trying to get data from a node that doesn't exist),
- * in which case the exception will be thrown without retries.
- *
- * @param fn Block to execute, possibly multiple times.
- */
- def retry[T](fn: => T, n: Int = MAX_RECONNECT_ATTEMPTS): T = {
- try {
- fn
- } catch {
- case e: KeeperException.NoNodeException => throw e
- case e: KeeperException.NodeExistsException => throw e
- case e: Exception if n > 0 =>
- logError("ZooKeeper exception, " + n + " more retries...", e)
- Thread.sleep(RETRY_WAIT_MILLIS)
- retry(fn, n-1)
- }
- }
-}
-
-trait SparkZooKeeperWatcher {
- /**
- * Called whenever a ZK session is created --
- * this will occur when we create our first session as well as each time
- * the session expires or errors out.
- */
- def zkSessionCreated()
-
- /**
- * Called if ZK appears to be completely down (i.e., not just a transient error).
- * We will no longer attempt to reconnect to ZK, and the SparkZooKeeperSession is considered dead.
- */
- def zkDown()
-}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
index 47b8f67f8a45b..285f9b014e291 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
@@ -18,105 +18,67 @@
package org.apache.spark.deploy.master
import akka.actor.ActorRef
-import org.apache.zookeeper._
-import org.apache.zookeeper.Watcher.Event.EventType
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.deploy.master.MasterMessages._
+import org.apache.curator.framework.CuratorFramework
+import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch}
private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef,
masterUrl: String, conf: SparkConf)
- extends LeaderElectionAgent with SparkZooKeeperWatcher with Logging {
+ extends LeaderElectionAgent with LeaderLatchListener with Logging {
val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/leader_election"
- private val watcher = new ZooKeeperWatcher()
- private val zk = new SparkZooKeeperSession(this, conf)
+ private var zk: CuratorFramework = _
+ private var leaderLatch: LeaderLatch = _
private var status = LeadershipStatus.NOT_LEADER
- private var myLeaderFile: String = _
- private var leaderUrl: String = _
override def preStart() {
+
logInfo("Starting ZooKeeper LeaderElection agent")
- zk.connect()
- }
+ zk = SparkCuratorUtil.newClient(conf)
+ leaderLatch = new LeaderLatch(zk, WORKING_DIR)
+ leaderLatch.addListener(this)
- override def zkSessionCreated() {
- synchronized {
- zk.mkdirRecursive(WORKING_DIR)
- myLeaderFile =
- zk.create(WORKING_DIR + "/master_", masterUrl.getBytes, CreateMode.EPHEMERAL_SEQUENTIAL)
- self ! CheckLeader
- }
+ leaderLatch.start()
}
override def preRestart(reason: scala.Throwable, message: scala.Option[scala.Any]) {
- logError("LeaderElectionAgent failed, waiting " + zk.ZK_TIMEOUT_MILLIS + "...", reason)
- Thread.sleep(zk.ZK_TIMEOUT_MILLIS)
+ logError("LeaderElectionAgent failed...", reason)
super.preRestart(reason, message)
}
- override def zkDown() {
- logError("ZooKeeper down! LeaderElectionAgent shutting down Master.")
- System.exit(1)
- }
-
override def postStop() {
+ leaderLatch.close()
zk.close()
}
override def receive = {
- case CheckLeader => checkLeader()
+ case _ =>
}
- private class ZooKeeperWatcher extends Watcher {
- def process(event: WatchedEvent) {
- if (event.getType == EventType.NodeDeleted) {
- logInfo("Leader file disappeared, a master is down!")
- self ! CheckLeader
+ override def isLeader() {
+ synchronized {
+ // could have lost leadership by now.
+ if (!leaderLatch.hasLeadership) {
+ return
}
- }
- }
- /** Uses ZK leader election. Navigates several ZK potholes along the way. */
- def checkLeader() {
- val masters = zk.getChildren(WORKING_DIR).toList
- val leader = masters.sorted.head
- val leaderFile = WORKING_DIR + "/" + leader
-
- // Setup a watch for the current leader.
- zk.exists(leaderFile, watcher)
-
- try {
- leaderUrl = new String(zk.getData(leaderFile))
- } catch {
- // A NoNodeException may be thrown if old leader died since the start of this method call.
- // This is fine -- just check again, since we're guaranteed to see the new values.
- case e: KeeperException.NoNodeException =>
- logInfo("Leader disappeared while reading it -- finding next leader")
- checkLeader()
- return
+ logInfo("We have gained leadership")
+ updateLeadershipStatus(true)
}
+ }
- // Synchronization used to ensure no interleaving between the creation of a new session and the
- // checking of a leader, which could cause us to delete our real leader file erroneously.
+ override def notLeader() {
synchronized {
- val isLeader = myLeaderFile == leaderFile
- if (!isLeader && leaderUrl == masterUrl) {
- // We found a different master file pointing to this process.
- // This can happen in the following two cases:
- // (1) The master process was restarted on the same node.
- // (2) The ZK server died between creating the file and returning the name of the file.
- // For this case, we will end up creating a second file, and MUST explicitly delete the
- // first one, since our ZK session is still open.
- // Note that this deletion will cause a NodeDeleted event to be fired so we check again for
- // leader changes.
- assert(leaderFile < myLeaderFile)
- logWarning("Cleaning up old ZK master election file that points to this master.")
- zk.delete(leaderFile)
- } else {
- updateLeadershipStatus(isLeader)
+ // could have gained leadership by now.
+ if (leaderLatch.hasLeadership) {
+ return
}
+
+ logInfo("We have lost leadership")
+ updateLeadershipStatus(false)
}
}
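// Standalone sketch of the Curator LeaderLatch recipe the agent now relies on, without the
// Akka wiring. The connect string and latch path are placeholders; both callbacks re-check
// hasLeadership, as the patch does, because leadership may change again between the
// notification and the callback running.
import org.apache.curator.framework.CuratorFrameworkFactory
import org.apache.curator.framework.recipes.leader.{LeaderLatch, LeaderLatchListener}
import org.apache.curator.retry.ExponentialBackoffRetry

object LeaderLatchExample extends App {
  val client = CuratorFrameworkFactory.newClient(
    "172.17.42.1:2181", new ExponentialBackoffRetry(5000, 3))
  client.start()

  val latch = new LeaderLatch(client, "/spark/example_leader_election")
  latch.addListener(new LeaderLatchListener {
    override def isLeader(): Unit =
      if (latch.hasLeadership) println("We have gained leadership")
    override def notLeader(): Unit =
      if (!latch.hasLeadership) println("We have lost leadership")
  })
  latch.start()

  Thread.sleep(10000) // pretend to do work while (possibly) holding leadership
  latch.close()
  client.close()
}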
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
index 48b2fc06a9d70..5413ff671ad8d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -17,36 +17,28 @@
package org.apache.spark.deploy.master
+import scala.collection.JavaConversions._
+
import akka.serialization.Serialization
-import org.apache.zookeeper._
+import org.apache.zookeeper.CreateMode
import org.apache.spark.{Logging, SparkConf}
class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf)
extends PersistenceEngine
- with SparkZooKeeperWatcher
with Logging
{
val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/master_status"
+ val zk = SparkCuratorUtil.newClient(conf)
- val zk = new SparkZooKeeperSession(this, conf)
-
- zk.connect()
-
- override def zkSessionCreated() {
- zk.mkdirRecursive(WORKING_DIR)
- }
-
- override def zkDown() {
- logError("PersistenceEngine disconnected from ZooKeeper -- ZK looks down.")
- }
+ SparkCuratorUtil.mkdir(zk, WORKING_DIR)
override def addApplication(app: ApplicationInfo) {
serializeIntoFile(WORKING_DIR + "/app_" + app.id, app)
}
override def removeApplication(app: ApplicationInfo) {
- zk.delete(WORKING_DIR + "/app_" + app.id)
+ zk.delete().forPath(WORKING_DIR + "/app_" + app.id)
}
override def addDriver(driver: DriverInfo) {
@@ -54,7 +46,7 @@ class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf)
}
override def removeDriver(driver: DriverInfo) {
- zk.delete(WORKING_DIR + "/driver_" + driver.id)
+ zk.delete().forPath(WORKING_DIR + "/driver_" + driver.id)
}
override def addWorker(worker: WorkerInfo) {
@@ -62,7 +54,7 @@ class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf)
}
override def removeWorker(worker: WorkerInfo) {
- zk.delete(WORKING_DIR + "/worker_" + worker.id)
+ zk.delete().forPath(WORKING_DIR + "/worker_" + worker.id)
}
override def close() {
@@ -70,26 +62,34 @@ class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf)
}
override def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
- val sortedFiles = zk.getChildren(WORKING_DIR).toList.sorted
+ val sortedFiles = zk.getChildren().forPath(WORKING_DIR).toList.sorted
val appFiles = sortedFiles.filter(_.startsWith("app_"))
- val apps = appFiles.map(deserializeFromFile[ApplicationInfo])
+ val apps = appFiles.map(deserializeFromFile[ApplicationInfo]).flatten
val driverFiles = sortedFiles.filter(_.startsWith("driver_"))
- val drivers = driverFiles.map(deserializeFromFile[DriverInfo])
+ val drivers = driverFiles.map(deserializeFromFile[DriverInfo]).flatten
val workerFiles = sortedFiles.filter(_.startsWith("worker_"))
- val workers = workerFiles.map(deserializeFromFile[WorkerInfo])
+ val workers = workerFiles.map(deserializeFromFile[WorkerInfo]).flatten
(apps, drivers, workers)
}
private def serializeIntoFile(path: String, value: AnyRef) {
val serializer = serialization.findSerializerFor(value)
val serialized = serializer.toBinary(value)
- zk.create(path, serialized, CreateMode.PERSISTENT)
+ zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized)
}
- def deserializeFromFile[T](filename: String)(implicit m: Manifest[T]): T = {
- val fileData = zk.getData(WORKING_DIR + "/" + filename)
+ def deserializeFromFile[T](filename: String)(implicit m: Manifest[T]): Option[T] = {
+ val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename)
val clazz = m.runtimeClass.asInstanceOf[Class[T]]
val serializer = serialization.serializerFor(clazz)
- serializer.fromBinary(fileData).asInstanceOf[T]
+ try {
+ Some(serializer.fromBinary(fileData).asInstanceOf[T])
+ } catch {
+ case e: Exception => {
+ logWarning("Exception while reading persisted file, deleting", e)
+ zk.delete().forPath(WORKING_DIR + "/" + filename)
+ None
+ }
+ }
}
}
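// The change above turns deserialization failures into Option values so that a corrupt
// persisted entry is skipped (and deleted) instead of aborting recovery. The same
// map-then-flatten pattern in isolation, with a toy decoder standing in for the serializer:
object OptionFlattenExample extends App {
  def decode(bytes: Array[Byte]): Option[Int] =
    try Some(new String(bytes, "UTF-8").trim.toInt)
    catch { case _: NumberFormatException => None }  // corrupt entry -> None, not an exception

  val persisted = Seq("1".getBytes("UTF-8"), "oops".getBytes("UTF-8"), "3".getBytes("UTF-8"))
  val recovered = persisted.map(decode).flatten
  println(recovered)  // List(1, 3): the bad entry is silently dropped
}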
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
index 5cc4adbe448b7..90cad3c37fda6 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
@@ -23,7 +23,8 @@ import scala.concurrent.Await
import scala.xml.Node
import akka.pattern.ask
-import net.liftweb.json.JsonAST.JValue
+import javax.servlet.http.HttpServletRequest
+import org.json4s.JValue
import org.apache.spark.deploy.JsonProtocol
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala
index 01c8f9065e50a..3233cd97f7bd0 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala
@@ -23,7 +23,8 @@ import scala.concurrent.Await
import scala.xml.Node
import akka.pattern.ask
-import net.liftweb.json.JsonAST.JValue
+import javax.servlet.http.HttpServletRequest
+import org.json4s.JValue
import org.apache.spark.deploy.{DeployWebUI, JsonProtocol}
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
@@ -85,6 +86,7 @@ private[spark] class IndexPage(parent: MasterWebUI) {
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
index 5ab13e7aa6b1f..4ad1f95be31c9 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
@@ -18,8 +18,8 @@
package org.apache.spark.deploy.master.ui
import javax.servlet.http.HttpServletRequest
-
-import org.eclipse.jetty.server.{Handler, Server}
+import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.Logging
import org.apache.spark.deploy.master.Master
@@ -46,7 +46,7 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging {
def start() {
try {
- val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers)
+ val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers, master.conf)
server = Some(srv)
boundPort = Some(bPort)
logInfo("Started Master web UI at http://%s:%d".format(host, boundPort.get))
@@ -60,12 +60,17 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging {
val metricsHandlers = master.masterMetricsSystem.getServletHandlers ++
master.applicationMetricsSystem.getServletHandlers
- val handlers = metricsHandlers ++ Array[(String, Handler)](
- ("/static", createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR)),
- ("/app/json", (request: HttpServletRequest) => applicationPage.renderJson(request)),
- ("/app", (request: HttpServletRequest) => applicationPage.render(request)),
- ("/json", (request: HttpServletRequest) => indexPage.renderJson(request)),
- ("*", (request: HttpServletRequest) => indexPage.render(request))
+ val handlers = metricsHandlers ++ Seq[ServletContextHandler](
+ createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR + "/static", "/static"),
+ createServletHandler("/app/json",
+ createServlet((request: HttpServletRequest) => applicationPage.renderJson(request),
+ master.securityMgr)),
+ createServletHandler("/app", createServlet((request: HttpServletRequest) => applicationPage
+ .render(request), master.securityMgr)),
+ createServletHandler("/json", createServlet((request: HttpServletRequest) => indexPage
+ .renderJson(request), master.securityMgr)),
+ createServletHandler("*", createServlet((request: HttpServletRequest) => indexPage.render
+ (request), master.securityMgr))
)
def stop() {
@@ -74,5 +79,5 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging {
}
private[spark] object MasterWebUI {
- val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static"
+ val STATIC_RESOURCE_DIR = "org/apache/spark/ui"
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
index a26e47950a0ec..be15138f62406 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
@@ -19,7 +19,7 @@ package org.apache.spark.deploy.worker
import akka.actor._
-import org.apache.spark.SparkConf
+import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.util.{AkkaUtils, Utils}
/**
@@ -29,8 +29,9 @@ object DriverWrapper {
def main(args: Array[String]) {
args.toList match {
case workerUrl :: mainClass :: extraArgs =>
+ val conf = new SparkConf()
val (actorSystem, _) = AkkaUtils.createActorSystem("Driver",
- Utils.localHostName(), 0, false, new SparkConf())
+ Utils.localHostName(), 0, false, conf, new SecurityManager(conf))
actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher")
// Delegate to supplied main class
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 7b0b7861b76e1..afaabedffefea 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -27,7 +27,7 @@ import scala.concurrent.duration._
import akka.actor._
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
-import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.{DriverState, Master}
@@ -48,7 +48,8 @@ private[spark] class Worker(
actorSystemName: String,
actorName: String,
workDirPath: String = null,
- val conf: SparkConf)
+ val conf: SparkConf,
+ val securityMgr: SecurityManager)
extends Actor with Logging {
import context.dispatcher
@@ -91,7 +92,7 @@ private[spark] class Worker(
var coresUsed = 0
var memoryUsed = 0
- val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf)
+ val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr)
val workerSource = new WorkerSource(this)
def coresFree: Int = cores - coresUsed
@@ -347,10 +348,11 @@ private[spark] object Worker {
val conf = new SparkConf
val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
val actorName = "Worker"
+ val securityMgr = new SecurityManager(conf)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port,
- conf = conf)
+ conf = conf, securityManager = securityMgr)
actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory,
- masterUrls, systemName, actorName, workDir, conf), name = actorName)
+ masterUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName)
(actorSystem, boundPort)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala
index 3089acffb8d98..85200ab0e102d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala
@@ -22,7 +22,7 @@ import scala.xml.Node
import akka.pattern.ask
import javax.servlet.http.HttpServletRequest
-import net.liftweb.json.JsonAST.JValue
+import org.json4s.JValue
import org.apache.spark.deploy.JsonProtocol
import org.apache.spark.deploy.DeployMessages.{RequestWorkerState, WorkerStateResponse}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index bdf126f93abc8..4e33b330ad4e7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -19,8 +19,8 @@ package org.apache.spark.deploy.worker.ui
import java.io.File
import javax.servlet.http.HttpServletRequest
-
-import org.eclipse.jetty.server.{Handler, Server}
+import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.Logging
import org.apache.spark.deploy.worker.Worker
@@ -33,7 +33,7 @@ import org.apache.spark.util.{AkkaUtils, Utils}
*/
private[spark]
class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[Int] = None)
- extends Logging {
+ extends Logging {
val timeout = AkkaUtils.askTimeout(worker.conf)
val host = Utils.localHostName()
val port = requestedPort.getOrElse(
@@ -46,17 +46,21 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I
val metricsHandlers = worker.metricsSystem.getServletHandlers
- val handlers = metricsHandlers ++ Array[(String, Handler)](
- ("/static", createStaticHandler(WorkerWebUI.STATIC_RESOURCE_DIR)),
- ("/log", (request: HttpServletRequest) => log(request)),
- ("/logPage", (request: HttpServletRequest) => logPage(request)),
- ("/json", (request: HttpServletRequest) => indexPage.renderJson(request)),
- ("*", (request: HttpServletRequest) => indexPage.render(request))
+ val handlers = metricsHandlers ++ Seq[ServletContextHandler](
+ createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE + "/static", "/static"),
+ createServletHandler("/log", createServlet((request: HttpServletRequest) => log(request),
+ worker.securityMgr)),
+ createServletHandler("/logPage", createServlet((request: HttpServletRequest) => logPage
+ (request), worker.securityMgr)),
+ createServletHandler("/json", createServlet((request: HttpServletRequest) => indexPage
+ .renderJson(request), worker.securityMgr)),
+ createServletHandler("*", createServlet((request: HttpServletRequest) => indexPage.render
+ (request), worker.securityMgr))
)
def start() {
try {
- val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers)
+ val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers, worker.conf)
server = Some(srv)
boundPort = Some(bPort)
logInfo("Started Worker web UI at http://%s:%d".format(host, bPort))
@@ -198,6 +202,6 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I
}
private[spark] object WorkerWebUI {
- val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static"
+ val STATIC_RESOURCE_BASE = "org/apache/spark/ui"
val DEFAULT_PORT="8081"
}
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 0aae569b17272..3486092a140fb 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -22,7 +22,7 @@ import java.nio.ByteBuffer
import akka.actor._
import akka.remote._
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{SecurityManager, SparkConf, Logging}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.deploy.worker.WorkerWatcher
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
@@ -97,10 +97,11 @@ private[spark] object CoarseGrainedExecutorBackend {
// Debug code
Utils.checkHost(hostname)
+ val conf = new SparkConf
// Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
// before getting started with all our system properties, etc
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0,
- indestructible = true, conf = new SparkConf)
+ indestructible = true, conf = conf, new SecurityManager(conf))
// set it
val sparkHostPort = hostname + ":" + boundPort
actorSystem.actorOf(
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 989d666f15600..2ea2ec29f59f5 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -29,7 +29,7 @@ import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler._
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{AkkaUtils, Utils}
/**
* Spark executor used with Mesos, YARN, and the standalone scheduler.
@@ -69,11 +69,6 @@ private[spark] class Executor(
conf.set("spark.local.dir", getYarnLocalDirs())
}
- // Create our ClassLoader and set it on this thread
- private val urlClassLoader = createClassLoader()
- private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
- Thread.currentThread.setContextClassLoader(replClassLoader)
-
if (!isLocal) {
// Setup an uncaught exception handler for non-local mode.
// Make any thread terminations due to uncaught exceptions kill the entire
@@ -117,11 +112,15 @@ private[spark] class Executor(
}
}
+ // Create our ClassLoader and set it on this thread
+  // Do this after SparkEnv creation so we can access the SecurityManager.
+ private val urlClassLoader = createClassLoader()
+ private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
+ Thread.currentThread.setContextClassLoader(replClassLoader)
+
// Akka's message frame size. If task result is bigger than this, we use the block manager
// to send the result back.
- private val akkaFrameSize = {
- env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size")
- }
+ private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
// Start worker thread pool
val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
@@ -338,12 +337,12 @@ private[spark] class Executor(
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager)
currentFiles(name) = timestamp
}
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager)
currentJars(name) = timestamp
// Add it to our class loader
val localName = name.split("/").last
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 455339943f42d..760458cb02a9b 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -103,13 +103,6 @@ class ShuffleReadMetrics extends Serializable {
*/
var fetchWaitTime: Long = _
- /**
- * Total time spent fetching remote shuffle blocks. This aggregates the time spent fetching all
- * input blocks. Since block fetches are both pipelined and parallelized, this can
- * exceed fetchWaitTime and executorRunTime.
- */
- var remoteFetchTime: Long = _
-
/**
* Total number of remote bytes read from the shuffle by this task
*/
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
index 966c092124266..c5bda2078fc14 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable
import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry}
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.metrics.sink.{MetricsServlet, Sink}
import org.apache.spark.metrics.source.Source
@@ -64,7 +64,7 @@ import org.apache.spark.metrics.source.Source
* [options] is the specific property of this source or sink.
*/
private[spark] class MetricsSystem private (val instance: String,
- conf: SparkConf) extends Logging {
+ conf: SparkConf, securityMgr: SecurityManager) extends Logging {
val confFile = conf.get("spark.metrics.conf", null)
val metricsConfig = new MetricsConfig(Option(confFile))
@@ -131,8 +131,8 @@ private[spark] class MetricsSystem private (val instance: String,
val classPath = kv._2.getProperty("class")
try {
val sink = Class.forName(classPath)
- .getConstructor(classOf[Properties], classOf[MetricRegistry])
- .newInstance(kv._2, registry)
+ .getConstructor(classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager])
+ .newInstance(kv._2, registry, securityMgr)
if (kv._1 == "servlet") {
metricsServlet = Some(sink.asInstanceOf[MetricsServlet])
} else {
@@ -160,6 +160,7 @@ private[spark] object MetricsSystem {
}
}
- def createMetricsSystem(instance: String, conf: SparkConf): MetricsSystem =
- new MetricsSystem(instance, conf)
+ def createMetricsSystem(instance: String, conf: SparkConf,
+ securityMgr: SecurityManager): MetricsSystem =
+ new MetricsSystem(instance, conf, securityMgr)
}
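
The MetricsSystem hunk above changes the reflective sink construction to require a (Properties, MetricRegistry, SecurityManager) constructor. Here is a small, self-contained sketch of that pattern; ToyRegistry, ToySecurityManager and ToySink are stand-ins, not the Codahale or Spark classes.

    import java.util.Properties

    // Toy stand-ins; the real code uses com.codahale.metrics.MetricRegistry,
    // org.apache.spark.SecurityManager and the Spark Sink trait.
    class ToySecurityManager
    class ToyRegistry
    trait ToySink { def start(): Unit }

    class ConsoleToySink(props: Properties, registry: ToyRegistry, securityMgr: ToySecurityManager)
      extends ToySink {
      def start(): Unit = println("sink started, period=" + props.getProperty("period", "10"))
    }

    object SinkLoader {
      def main(args: Array[String]): Unit = {
        val props = new Properties()
        props.setProperty("period", "5")
        // after this change every sink class is expected to expose a
        // (Properties, Registry, SecurityManager) constructor
        val sink = Class.forName("ConsoleToySink")
          .getConstructor(classOf[Properties], classOf[ToyRegistry], classOf[ToySecurityManager])
          .newInstance(props, new ToyRegistry, new ToySecurityManager)
          .asInstanceOf[ToySink]
        sink.start()
      }
    }
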
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
index 98fa1dbd7c6ab..4d2ffc54d8983 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
@@ -22,9 +22,11 @@ import java.util.concurrent.TimeUnit
import com.codahale.metrics.{ConsoleReporter, MetricRegistry}
+import org.apache.spark.SecurityManager
import org.apache.spark.metrics.MetricsSystem
-class ConsoleSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+class ConsoleSink(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
val CONSOLE_DEFAULT_PERIOD = 10
val CONSOLE_DEFAULT_UNIT = "SECONDS"
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
index 40f64768e6885..319f40815d65f 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
@@ -23,9 +23,11 @@ import java.util.concurrent.TimeUnit
import com.codahale.metrics.{CsvReporter, MetricRegistry}
+import org.apache.spark.SecurityManager
import org.apache.spark.metrics.MetricsSystem
-class CsvSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+class CsvSink(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
val CSV_KEY_PERIOD = "period"
val CSV_KEY_UNIT = "unit"
val CSV_KEY_DIR = "directory"
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
index e09be001421fc..0ffdf3846dc4a 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
@@ -24,9 +24,11 @@ import java.util.concurrent.TimeUnit
import com.codahale.metrics.MetricRegistry
import com.codahale.metrics.graphite.{Graphite, GraphiteReporter}
+import org.apache.spark.SecurityManager
import org.apache.spark.metrics.MetricsSystem
-class GraphiteSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+class GraphiteSink(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
val GRAPHITE_DEFAULT_PERIOD = 10
val GRAPHITE_DEFAULT_UNIT = "SECONDS"
val GRAPHITE_DEFAULT_PREFIX = ""
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
index b5cf210af2119..3b5edd5c376f0 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
@@ -20,8 +20,11 @@ package org.apache.spark.metrics.sink
import java.util.Properties
import com.codahale.metrics.{JmxReporter, MetricRegistry}
+import org.apache.spark.SecurityManager
+
+class JmxSink(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
-class JmxSink(val property: Properties, val registry: MetricRegistry) extends Sink {
val reporter: JmxReporter = JmxReporter.forRegistry(registry).build()
override def start() {
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
index 3cdfe26d40f66..3110eccdee4fc 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
@@ -19,16 +19,19 @@ package org.apache.spark.metrics.sink
import java.util.Properties
import java.util.concurrent.TimeUnit
+
import javax.servlet.http.HttpServletRequest
import com.codahale.metrics.MetricRegistry
import com.codahale.metrics.json.MetricsModule
import com.fasterxml.jackson.databind.ObjectMapper
-import org.eclipse.jetty.server.Handler
+import org.eclipse.jetty.servlet.ServletContextHandler
+import org.apache.spark.SecurityManager
import org.apache.spark.ui.JettyUtils
-class MetricsServlet(val property: Properties, val registry: MetricRegistry) extends Sink {
+class MetricsServlet(val property: Properties, val registry: MetricRegistry,
+ securityMgr: SecurityManager) extends Sink {
val SERVLET_KEY_PATH = "path"
val SERVLET_KEY_SAMPLE = "sample"
@@ -42,8 +45,11 @@ class MetricsServlet(val property: Properties, val registry: MetricRegistry) ext
val mapper = new ObjectMapper().registerModule(
new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample))
- def getHandlers = Array[(String, Handler)](
- (servletPath, JettyUtils.createHandler(request => getMetricsSnapshot(request), "text/json"))
+ def getHandlers = Array[ServletContextHandler](
+ JettyUtils.createServletHandler(servletPath,
+ JettyUtils.createServlet(
+ new JettyUtils.ServletParams(request => getMetricsSnapshot(request), "text/json"),
+ securityMgr) )
)
def getMetricsSnapshot(request: HttpServletRequest): String = {
diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
index d3c09b16063d6..04df2f3b0d696 100644
--- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
+++ b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
@@ -45,9 +45,10 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
throw new Exception("Max chunk size is " + maxChunkSize)
}
+ val security = if (isSecurityNeg) 1 else 0
if (size == 0 && !gotChunkForSendingOnce) {
val newChunk = new MessageChunk(
- new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
+ new MessageChunkHeader(typ, id, 0, 0, ackId, security, senderAddress), null)
gotChunkForSendingOnce = true
return Some(newChunk)
}
@@ -65,7 +66,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
}
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
- typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
gotChunkForSendingOnce = true
return Some(newChunk)
}
@@ -79,6 +80,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
throw new Exception("Attempting to get chunk from message with multiple data buffers")
}
val buffer = buffers(0)
+ val security = if (isSecurityNeg) 1 else 0
if (buffer.remaining > 0) {
if (buffer.remaining < chunkSize) {
throw new Exception("Not enough space in data buffer for receiving chunk")
@@ -86,7 +88,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
- typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
return Some(newChunk)
}
None
diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala
index f2e3c1a14ecc6..8fd9c2b87d256 100644
--- a/core/src/main/scala/org/apache/spark/network/Connection.scala
+++ b/core/src/main/scala/org/apache/spark/network/Connection.scala
@@ -17,6 +17,11 @@
package org.apache.spark.network
+import org.apache.spark._
+import org.apache.spark.SparkSaslServer
+
+import scala.collection.mutable.{HashMap, Queue, ArrayBuffer}
+
import java.net._
import java.nio._
import java.nio.channels._
@@ -27,13 +32,16 @@ import org.apache.spark._
private[spark]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
- val socketRemoteConnectionManagerId: ConnectionManagerId)
+ val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId)
extends Logging {
- def this(channel_ : SocketChannel, selector_ : Selector) = {
+ var sparkSaslServer: SparkSaslServer = null
+ var sparkSaslClient: SparkSaslClient = null
+
+ def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = {
this(channel_, selector_,
ConnectionManagerId.fromSocketAddress(
- channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]))
+ channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]), id_)
}
channel.configureBlocking(false)
@@ -49,6 +57,16 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
val remoteAddress = getRemoteAddress()
+ /**
+ * Used to synchronize client requests: client's work-related requests must
+ * wait until SASL authentication completes.
+ */
+ private val authenticated = new Object()
+
+ def getAuthenticated(): Object = authenticated
+
+ def isSaslComplete(): Boolean
+
def resetForceReregister(): Boolean
// Read channels typically do not register for write and write does not for read
@@ -69,6 +87,16 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
// Will be true for ReceivingConnection, false for SendingConnection.
def changeInterestForRead(): Boolean
+ private def disposeSasl() {
+ if (sparkSaslServer != null) {
+ sparkSaslServer.dispose();
+ }
+
+ if (sparkSaslClient != null) {
+ sparkSaslClient.dispose()
+ }
+ }
+
// On receiving a write event, should we change the interest for this channel or not ?
// Will be false for ReceivingConnection, true for SendingConnection.
// Actually, for now, should not get triggered for ReceivingConnection
@@ -101,6 +129,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
k.cancel()
}
channel.close()
+ disposeSasl()
callOnCloseCallback()
}
@@ -168,10 +197,14 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
private[spark]
class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
- remoteId_ : ConnectionManagerId)
- extends Connection(SocketChannel.open, selector_, remoteId_) {
+ remoteId_ : ConnectionManagerId, id_ : ConnectionId)
+ extends Connection(SocketChannel.open, selector_, remoteId_, id_) {
- private class Outbox(fair: Int = 0) {
+ def isSaslComplete(): Boolean = {
+ if (sparkSaslClient != null) sparkSaslClient.isComplete() else false
+ }
+
+ private class Outbox {
val messages = new Queue[Message]()
val defaultChunkSize = 65536 //32768 //16384
var nextMessageToBeUsed = 0
@@ -186,38 +219,6 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
}
def getChunk(): Option[MessageChunk] = {
- fair match {
- case 0 => getChunkFIFO()
- case 1 => getChunkRR()
- case _ => throw new Exception("Unexpected fairness policy in outbox")
- }
- }
-
- private def getChunkFIFO(): Option[MessageChunk] = {
- /*logInfo("Using FIFO")*/
- messages.synchronized {
- while (!messages.isEmpty) {
- val message = messages(0)
- val chunk = message.getChunkForSending(defaultChunkSize)
- if (chunk.isDefined) {
- messages += message // this is probably incorrect, it wont work as fifo
- if (!message.started) {
- logDebug("Starting to send [" + message + "]")
- message.started = true
- message.startTime = System.currentTimeMillis
- }
- return chunk
- } else {
- message.finishTime = System.currentTimeMillis
- logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
- "] in " + message.timeTaken )
- }
- }
- }
- None
- }
-
- private def getChunkRR(): Option[MessageChunk] = {
messages.synchronized {
while (!messages.isEmpty) {
/*nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
@@ -249,7 +250,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
// outbox is used as a lock - ensure that it is always used as a leaf (since methods which
// lock it are invoked in context of other locks)
- private val outbox = new Outbox(1)
+ private val outbox = new Outbox()
/*
This is orthogonal to whether we have pending bytes to write or not - and satisfies a slightly
different purpose. This flag is to see if we need to force reregister for write even when we
@@ -258,6 +259,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
data as detailed in https://github.com/mesos/spark/pull/791
*/
private var needForceReregister = false
+
val currentBuffers = new ArrayBuffer[ByteBuffer]()
/*channel.socket.setSendBufferSize(256 * 1024)*/
@@ -348,6 +350,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
// If we have 'seen' pending messages, then reset flag - since we handle that as
// normal registering of event (below)
if (needForceReregister && buffers.exists(_.remaining() > 0)) resetForceReregister()
+
currentBuffers ++= buffers
}
case None => {
@@ -416,8 +419,15 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
// Must be created within selector loop - else deadlock
-private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
- extends Connection(channel_, selector_) {
+private[spark] class ReceivingConnection(
+ channel_ : SocketChannel,
+ selector_ : Selector,
+ id_ : ConnectionId)
+ extends Connection(channel_, selector_, id_) {
+
+ def isSaslComplete(): Boolean = {
+ if (sparkSaslServer != null) sparkSaslServer.isComplete() else false
+ }
class Inbox() {
val messages = new HashMap[Int, BufferMessage]()
@@ -428,6 +438,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S
val newMessage = Message.create(header).asInstanceOf[BufferMessage]
newMessage.started = true
newMessage.startTime = System.currentTimeMillis
+ newMessage.isSecurityNeg = header.securityNeg == 1
logDebug(
"Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]")
messages += ((newMessage.id, newMessage))
@@ -473,7 +484,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S
val inbox = new Inbox()
val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
- var onReceiveCallback: (Connection , Message) => Unit = null
+ var onReceiveCallback: (Connection, Message) => Unit = null
var currentChunk: MessageChunk = null
channel.register(selector, SelectionKey.OP_READ)
@@ -548,7 +559,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S
}
}
} catch {
- case e: Exception => {
+ case e: Exception => {
logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e)
callOnExceptionCallback(e)
close()
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/VoidFunction.scala b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala
similarity index 55%
rename from core/src/main/scala/org/apache/spark/api/java/function/VoidFunction.scala
rename to core/src/main/scala/org/apache/spark/network/ConnectionId.scala
index ea94313a4ab59..ffaab677d411a 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/VoidFunction.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala
@@ -15,19 +15,20 @@
* limitations under the License.
*/
-package org.apache.spark.api.java.function
+package org.apache.spark.network
-/**
- * A function with no return value.
- */
-// This allows Java users to write void methods without having to return Unit.
-abstract class VoidFunction[T] extends Serializable {
- @throws(classOf[Exception])
- def call(t: T) : Unit
+private[spark] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) {
+ override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId
}
-// VoidFunction cannot extend AbstractFunction1 (because that would force users to explicitly
-// return Unit), so it is implicitly converted to a Function1[T, Unit]:
-object VoidFunction {
- implicit def toFunction[T](f: VoidFunction[T]) : Function1[T, Unit] = ((x : T) => f.call(x))
+private[spark] object ConnectionId {
+
+ def createConnectionIdFromString(connectionIdString: String): ConnectionId = {
+ val res = connectionIdString.split("_").map(_.trim())
+ if (res.size != 3) {
+ throw new Exception("Error converting ConnectionId string: " + connectionIdString +
+ " to a ConnectionId Object")
+ }
+ new ConnectionId(new ConnectionManagerId(res(0), res(1).toInt), res(2).toInt)
+ }
}
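
The new ConnectionId encodes host, port and a per-manager counter as host_port_uniqId and parses it back in createConnectionIdFromString. A stand-alone round-trip sketch of that format follows; the Toy* case classes are placeholders, not the Spark ones.

    // Toy stand-ins mirroring the host_port_uniqId format above; not the Spark classes.
    case class ToyManagerId(host: String, port: Int)

    case class ToyConnectionId(managerId: ToyManagerId, uniqId: Int) {
      override def toString = managerId.host + "_" + managerId.port + "_" + uniqId
    }

    object ToyConnectionId {
      def fromString(s: String): ToyConnectionId = {
        val parts = s.split("_").map(_.trim)
        require(parts.length == 3, "expected host_port_uniqId, got: " + s)
        ToyConnectionId(ToyManagerId(parts(0), parts(1).toInt), parts(2).toInt)
      }
    }

    object ConnectionIdRoundTrip {
      def main(args: Array[String]): Unit = {
        val id = ToyConnectionId(ToyManagerId("worker-1", 7077), 42)
        println(id)                                        // worker-1_7077_42
        assert(ToyConnectionId.fromString(id.toString) == id)
      }
    }
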
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
index 3dd82bee0b5fd..a75130cba2a2e 100644
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
@@ -21,6 +21,9 @@ import java.net._
import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
+import java.net._
+import java.util.concurrent.atomic.AtomicInteger
+
import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor}
import scala.collection.mutable.ArrayBuffer
@@ -28,13 +31,15 @@ import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.collection.mutable.SynchronizedMap
import scala.collection.mutable.SynchronizedQueue
+
import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.concurrent.duration._
import org.apache.spark._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{SystemClock, Utils}
-private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Logging {
+private[spark] class ConnectionManager(port: Int, conf: SparkConf,
+ securityManager: SecurityManager) extends Logging {
class MessageStatus(
val message: Message,
@@ -50,6 +55,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
private val selector = SelectorProvider.provider.openSelector()
+ // default to 30 second timeout waiting for authentication
+ private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30)
+
private val handleMessageExecutor = new ThreadPoolExecutor(
conf.getInt("spark.core.connection.handler.threads.min", 20),
conf.getInt("spark.core.connection.handler.threads.max", 60),
@@ -71,6 +79,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
new LinkedBlockingDeque[Runnable]())
private val serverChannel = ServerSocketChannel.open()
+ // used to track the SendingConnections waiting to do SASL negotiation
+ private val connectionsAwaitingSasl = new HashMap[ConnectionId, SendingConnection]
+ with SynchronizedMap[ConnectionId, SendingConnection]
private val connectionsByKey =
new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection]
@@ -84,6 +95,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
+ private val authEnabled = securityManager.isAuthenticationEnabled()
+
serverChannel.configureBlocking(false)
serverChannel.socket.setReuseAddress(true)
serverChannel.socket.setReceiveBufferSize(256 * 1024)
@@ -94,6 +107,10 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)
+ // used in combination with the ConnectionManagerId to create unique Connection ids
+ // to be able to track asynchronous messages
+ private val idCount: AtomicInteger = new AtomicInteger(1)
+
private val selectorThread = new Thread("connection-manager-thread") {
override def run() = ConnectionManager.this.run()
}
@@ -125,7 +142,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
} finally {
writeRunnableStarted.synchronized {
writeRunnableStarted -= key
- val needReregister = register || conn.resetForceReregister()
+ val needReregister = register || conn.resetForceReregister()
if (needReregister && conn.changeInterestForWrite()) {
conn.registerInterest()
}
@@ -372,7 +389,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
// accept them all in a tight loop. non blocking accept with no processing, should be fine
while (newChannel != null) {
try {
- val newConnection = new ReceivingConnection(newChannel, selector)
+ val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
+ val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId)
newConnection.onReceive(receiveMessage)
addListeners(newConnection)
addConnection(newConnection)
@@ -406,6 +424,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
connectionsById -= sendingConnectionManagerId
+ connectionsAwaitingSasl -= connection.connectionId
messageStatuses.synchronized {
messageStatuses
@@ -481,7 +500,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
val creationTime = System.currentTimeMillis
def run() {
logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms")
- handleMessage(connectionManagerId, message)
+ handleMessage(connectionManagerId, message, connection)
logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms")
}
}
@@ -489,10 +508,133 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
/*handleMessage(connection, message)*/
}
- private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) {
+ private def handleClientAuthentication(
+ waitingConn: SendingConnection,
+ securityMsg: SecurityMessage,
+ connectionId : ConnectionId) {
+ if (waitingConn.isSaslComplete()) {
+ logDebug("Client sasl completed for id: " + waitingConn.connectionId)
+ connectionsAwaitingSasl -= waitingConn.connectionId
+ waitingConn.getAuthenticated().synchronized {
+ waitingConn.getAuthenticated().notifyAll();
+ }
+ return
+ } else {
+ var replyToken : Array[Byte] = null
+ try {
+ replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken);
+ if (waitingConn.isSaslComplete()) {
+ logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId)
+ connectionsAwaitingSasl -= waitingConn.connectionId
+ waitingConn.getAuthenticated().synchronized {
+ waitingConn.getAuthenticated().notifyAll()
+ }
+ return
+ }
+ var securityMsgResp = SecurityMessage.fromResponse(replyToken,
+ securityMsg.getConnectionId.toString())
+ var message = securityMsgResp.toBufferMessage
+ if (message == null) throw new Exception("Error creating security message")
+ sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message)
+ } catch {
+ case e: Exception => {
+ logError("Error handling sasl client authentication", e)
+ waitingConn.close()
+ throw new Exception("Error evaluating sasl response: " + e)
+ }
+ }
+ }
+ }
+
+ private def handleServerAuthentication(
+ connection: Connection,
+ securityMsg: SecurityMessage,
+ connectionId: ConnectionId) {
+ if (!connection.isSaslComplete()) {
+ logDebug("saslContext not established")
+ var replyToken : Array[Byte] = null
+ try {
+ connection.synchronized {
+ if (connection.sparkSaslServer == null) {
+ logDebug("Creating sasl Server")
+ connection.sparkSaslServer = new SparkSaslServer(securityManager)
+ }
+ }
+ replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
+ if (connection.isSaslComplete()) {
+ logDebug("Server sasl completed: " + connection.connectionId)
+ } else {
+ logDebug("Server sasl not completed: " + connection.connectionId)
+ }
+ if (replyToken != null) {
+ var securityMsgResp = SecurityMessage.fromResponse(replyToken,
+ securityMsg.getConnectionId)
+ var message = securityMsgResp.toBufferMessage
+ if (message == null) throw new Exception("Error creating security Message")
+ sendSecurityMessage(connection.getRemoteConnectionManagerId(), message)
+ }
+ } catch {
+ case e: Exception => {
+ logError("Error in server auth negotiation: " + e)
+ // It would probably be better to send an error message telling the other side that auth failed
+ // but for now just close
+ connection.close()
+ }
+ }
+ } else {
+ logDebug("connection already established for this connection id: " + connection.connectionId)
+ }
+ }
+
+
+ private def handleAuthentication(conn: Connection, bufferMessage: BufferMessage): Boolean = {
+ if (bufferMessage.isSecurityNeg) {
+ logDebug("This is security neg message")
+
+ // parse as SecurityMessage
+ val securityMsg = SecurityMessage.fromBufferMessage(bufferMessage)
+ val connectionId = ConnectionId.createConnectionIdFromString(securityMsg.getConnectionId)
+
+ connectionsAwaitingSasl.get(connectionId) match {
+ case Some(waitingConn) => {
+ // Client - this must be in response to us doing Send
+ logDebug("Client handleAuth for id: " + waitingConn.connectionId)
+ handleClientAuthentication(waitingConn, securityMsg, connectionId)
+ }
+ case None => {
+ // Server - someone sent us something and we haven't authenticated yet
+ logDebug("Server handleAuth for id: " + connectionId)
+ handleServerAuthentication(conn, securityMsg, connectionId)
+ }
+ }
+ return true
+ } else {
+ if (!conn.isSaslComplete()) {
+ // We could handle this better and tell the client we need to do authentication
+ // negotiation, but for now just ignore them.
+ logError("message sent that is not security negotiation message on connection " +
+ "not authenticated yet, ignoring it!!")
+ return true
+ }
+ }
+ return false
+ }
+
+ private def handleMessage(
+ connectionManagerId: ConnectionManagerId,
+ message: Message,
+ connection: Connection) {
logDebug("Handling [" + message + "] from [" + connectionManagerId + "]")
message match {
case bufferMessage: BufferMessage => {
+ if (authEnabled) {
+ val res = handleAuthentication(connection, bufferMessage)
+ if (res == true) {
+ // message was security negotiation so skip the rest
+ logDebug("After handleAuth result was true, returning")
+ return
+ }
+ }
if (bufferMessage.hasAckId) {
val sentMessageStatus = messageStatuses.synchronized {
messageStatuses.get(bufferMessage.ackId) match {
@@ -541,20 +683,124 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
}
}
+ private def checkSendAuthFirst(connManagerId: ConnectionManagerId, conn: SendingConnection) {
+ // see if we need to do sasl before writing
+ // this should only happen on the first negotiation, when we act as the client
+ if (!conn.isSaslComplete()) {
+ conn.synchronized {
+ if (conn.sparkSaslClient == null) {
+ conn.sparkSaslClient = new SparkSaslClient(securityManager)
+ var firstResponse: Array[Byte] = null
+ try {
+ firstResponse = conn.sparkSaslClient.firstToken()
+ var securityMsg = SecurityMessage.fromResponse(firstResponse,
+ conn.connectionId.toString())
+ var message = securityMsg.toBufferMessage
+ if (message == null) throw new Exception("Error creating security message")
+ connectionsAwaitingSasl += ((conn.connectionId, conn))
+ sendSecurityMessage(connManagerId, message)
+ logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId)
+ } catch {
+ case e: Exception => {
+ logError("Error getting first response from the SaslClient.", e)
+ conn.close()
+ throw new Exception("Error getting first response from the SaslClient")
+ }
+ }
+ }
+ }
+ } else {
+ logDebug("Sasl already established ")
+ }
+ }
+
+ // allows us to add messages to the inbox for doing SASL negotiation
+ private def sendSecurityMessage(connManagerId: ConnectionManagerId, message: Message) {
+ def startNewConnection(): SendingConnection = {
+ val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port)
+ val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
+ val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId,
+ newConnectionId)
+ logInfo("creating new sending connection for security! " + newConnectionId )
+ registerRequests.enqueue(newConnection)
+
+ newConnection
+ }
+ // I removed the lookupKey stuff as part of merge ... should I re-add it ?
+ // We did not find it useful in our test-env ...
+ // If we do re-add it, we should consistently use it everywhere I guess ?
+ message.senderAddress = id.toSocketAddress()
+ logTrace("Sending Security [" + message + "] to [" + connManagerId + "]")
+ val connection = connectionsById.getOrElseUpdate(connManagerId, startNewConnection())
+
+ // send the security message even though the outgoing connection has not been authenticated yet
+ connection.send(message)
+
+ wakeupSelector()
+ }
+
private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
def startNewConnection(): SendingConnection = {
val inetSocketAddress = new InetSocketAddress(connectionManagerId.host,
connectionManagerId.port)
- val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId)
+ val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
+ val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId,
+ newConnectionId)
+ logTrace("creating new sending connection: " + newConnectionId)
registerRequests.enqueue(newConnection)
newConnection
}
- // I removed the lookupKey stuff as part of merge ... should I re-add it ? We did not find it
- // useful in our test-env ... If we do re-add it, we should consistently use it everywhere I
- // guess ?
val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection())
+ if (authEnabled) {
+ checkSendAuthFirst(connectionManagerId, connection)
+ }
message.senderAddress = id.toSocketAddress()
+ logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " +
+ "connectionid: " + connection.connectionId)
+
+ if (authEnabled) {
+ // if we aren't authenticated yet, block the senders until authentication completes
+ try {
+ connection.getAuthenticated().synchronized {
+ val clock = SystemClock
+ val startTime = clock.getTime()
+
+ while (!connection.isSaslComplete()) {
+ logDebug("getAuthenticated wait connectionid: " + connection.connectionId)
+ // have timeout in case remote side never responds
+ connection.getAuthenticated().wait(500)
+ if (((clock.getTime() - startTime) >= (authTimeout * 1000))
+ && (!connection.isSaslComplete())) {
+ // took too long to authenticate the connection, something probably went wrong
+ throw new Exception("Took too long for authentication to " + connectionManagerId +
+ ", waited " + authTimeout + " seconds, failing.")
+ }
+ }
+ }
+ } catch {
+ case e: Exception => logError("Exception while waiting for authentication.", e)
+
+ // need to tell sender it failed
+ messageStatuses.synchronized {
+ val s = messageStatuses.get(message.id)
+ s match {
+ case Some(msgStatus) => {
+ messageStatuses -= message.id
+ logInfo("Notifying " + msgStatus.connectionManagerId)
+ msgStatus.synchronized {
+ msgStatus.attempted = true
+ msgStatus.acked = false
+ msgStatus.markDone()
+ }
+ }
+ case None => {
+ logError("no messageStatus for failed message id: " + message.id)
+ }
+ }
+ }
+ }
+ }
logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
connection.send(message)
@@ -606,7 +852,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi
private[spark] object ConnectionManager {
def main(args: Array[String]) {
- val manager = new ConnectionManager(9999, new SparkConf)
+ val conf = new SparkConf
+ val manager = new ConnectionManager(9999, conf, new SecurityManager(conf))
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
println("Received [" + msg + "] from [" + id + "]")
None
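
sendMessage above blocks callers on the connection's authenticated monitor, waking every 500 ms until SASL completes or spark.core.connection.auth.wait.timeout elapses. The following is a self-contained sketch of that wait-with-timeout pattern; it is toy code, not the Spark ConnectionManager.

    // Callers wait on a monitor in 500 ms slices until a completion flag flips
    // or a configurable timeout elapses.
    object AuthWaitSketch {
      @volatile private var saslComplete = false
      private val authenticated = new Object

      def markComplete(): Unit = authenticated.synchronized {
        saslComplete = true
        authenticated.notifyAll()
      }

      def awaitAuthenticated(timeoutSeconds: Int): Unit = authenticated.synchronized {
        val start = System.currentTimeMillis()
        while (!saslComplete) {
          authenticated.wait(500)  // bounded wait in case the remote side never responds
          if (!saslComplete && System.currentTimeMillis() - start >= timeoutSeconds * 1000L) {
            throw new Exception("Took too long to authenticate, waited " + timeoutSeconds + " seconds")
          }
        }
      }

      def main(args: Array[String]): Unit = {
        new Thread(new Runnable {
          def run(): Unit = { Thread.sleep(200); markComplete() }
        }).start()
        awaitAuthenticated(timeoutSeconds = 30)
        println("authenticated, safe to send the queued message")
      }
    }
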
diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala
index 20fe67661844f..7caccfdbb44f9 100644
--- a/core/src/main/scala/org/apache/spark/network/Message.scala
+++ b/core/src/main/scala/org/apache/spark/network/Message.scala
@@ -27,6 +27,7 @@ private[spark] abstract class Message(val typ: Long, val id: Int) {
var started = false
var startTime = -1L
var finishTime = -1L
+ var isSecurityNeg = false
def size: Int
diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
index 9bcbc6141a502..ead663ede7a1c 100644
--- a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
+++ b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
@@ -27,6 +27,7 @@ private[spark] class MessageChunkHeader(
val totalSize: Int,
val chunkSize: Int,
val other: Int,
+ val securityNeg: Int,
val address: InetSocketAddress) {
lazy val buffer = {
// No need to change this, at 'use' time, we do a reverse lookup of the hostname.
@@ -40,6 +41,7 @@ private[spark] class MessageChunkHeader(
putInt(totalSize).
putInt(chunkSize).
putInt(other).
+ putInt(securityNeg).
putInt(ip.size).
put(ip).
putInt(port).
@@ -48,12 +50,13 @@ private[spark] class MessageChunkHeader(
}
override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
- " and sizes " + totalSize + " / " + chunkSize + " bytes"
+ " and sizes " + totalSize + " / " + chunkSize + " bytes, securityNeg: " + securityNeg
+
}
private[spark] object MessageChunkHeader {
- val HEADER_SIZE = 40
+ val HEADER_SIZE = 44
def create(buffer: ByteBuffer): MessageChunkHeader = {
if (buffer.remaining != HEADER_SIZE) {
@@ -64,11 +67,13 @@ private[spark] object MessageChunkHeader {
val totalSize = buffer.getInt()
val chunkSize = buffer.getInt()
val other = buffer.getInt()
+ val securityNeg = buffer.getInt()
val ipSize = buffer.getInt()
val ipBytes = new Array[Byte](ipSize)
buffer.get(ipBytes)
val ip = InetAddress.getByAddress(ipBytes)
val port = buffer.getInt()
- new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port))
+ new MessageChunkHeader(typ, id, totalSize, chunkSize, other, securityNeg,
+ new InetSocketAddress(ip, port))
}
}
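
Adding the 4-byte securityNeg field is why HEADER_SIZE moves from 40 to 44: the writer, the reader and the size constant have to change in lockstep. The toy round trip below illustrates that idea; the field list is illustrative, not the actual Spark header layout.

    import java.nio.ByteBuffer

    object ToyHeader {
      val HEADER_SIZE = 16  // id(4) + chunkSize(4) + other(4) + securityNeg(4)

      def encode(id: Int, chunkSize: Int, other: Int, securityNeg: Int): ByteBuffer = {
        val buf = ByteBuffer.allocate(HEADER_SIZE)
        buf.putInt(id).putInt(chunkSize).putInt(other).putInt(securityNeg)
        buf.flip()
        buf
      }

      def decode(buf: ByteBuffer): (Int, Int, Int, Int) = {
        require(buf.remaining == HEADER_SIZE, "unexpected header size: " + buf.remaining)
        (buf.getInt(), buf.getInt(), buf.getInt(), buf.getInt())
      }

      def main(args: Array[String]): Unit = {
        val decoded = decode(encode(id = 7, chunkSize = 65536, other = -1, securityNeg = 1))
        println(decoded)  // (7,65536,-1,1)
        assert(decoded == ((7, 65536, -1, 1)))
      }
    }
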
diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
index 9976255c7e251..3c09a713c6fe0 100644
--- a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
+++ b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
@@ -18,12 +18,12 @@
package org.apache.spark.network
import java.nio.ByteBuffer
-
-import org.apache.spark.SparkConf
+import org.apache.spark.{SecurityManager, SparkConf}
private[spark] object ReceiverTest {
def main(args: Array[String]) {
- val manager = new ConnectionManager(9999, new SparkConf)
+ val conf = new SparkConf
+ val manager = new ConnectionManager(9999, conf, new SecurityManager(conf))
println("Started connection manager with id = " + manager.id)
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
diff --git a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala
new file mode 100644
index 0000000000000..0d9f743b3624b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.StringBuilder
+
+import org.apache.spark._
+import org.apache.spark.network._
+
+/**
+ * SecurityMessage is a class that contains the connectionId and SASL token
+ * used in SASL negotiation. SecurityMessage has routines for converting
+ * it to and from a BufferMessage so that it can be sent by the ConnectionManager
+ * and easily consumed by users when received.
+ * The API was modeled after BlockMessage.
+ *
+ * The connectionId is the connectionId of the client side. Since
+ * message passing is asynchronous and its possible for the server side (receiving)
+ * to get multiple different types of messages on the same connection the connectionId
+ * is used to know which connnection the security message is intended for.
+ *
+ * For instance, let's say we are node_0. We need to send data to node_1. The node_0 side
+ * is acting as a client and connecting to node_1. SASL negotiation has to occur
+ * between node_0 and node_1 before node_1 trusts node_0 so node_0 sends a security message.
+ * node_1 receives the message from node_0 but before it can process it and send a response,
+ * some thread on node_1 decides it needs to send data to node_0 so it connects to node_0
+ * and sends a security message of its own to authenticate as a client. Now node_0 gets
+ * the message and it needs to decide if this message is in response to it being a client
+ * (from the first send) or if it is just node_1 trying to connect to it to send data. This
+ * is where the connectionId field is used. node_0 can look up the connectionId to see if
+ * it is in response to it being a client or in response to someone sending other data.
+ *
+ * The format of a SecurityMessage as its sent is:
+ * - Length of the ConnectionId
+ * - ConnectionId
+ * - Length of the token
+ * - Token
+ */
+private[spark] class SecurityMessage() extends Logging {
+
+ private var connectionId: String = null
+ private var token: Array[Byte] = null
+
+ def set(byteArr: Array[Byte], newconnectionId: String) {
+ if (byteArr == null) {
+ token = new Array[Byte](0)
+ } else {
+ token = byteArr
+ }
+ connectionId = newconnectionId
+ }
+
+ /**
+ * Read the given buffer and set the members of this class.
+ */
+ def set(buffer: ByteBuffer) {
+ val idLength = buffer.getInt()
+ val idBuilder = new StringBuilder(idLength)
+ for (i <- 1 to idLength) {
+ idBuilder += buffer.getChar()
+ }
+ connectionId = idBuilder.toString()
+
+ val tokenLength = buffer.getInt()
+ token = new Array[Byte](tokenLength)
+ if (tokenLength > 0) {
+ buffer.get(token, 0, tokenLength)
+ }
+ }
+
+ def set(bufferMsg: BufferMessage) {
+ val buffer = bufferMsg.buffers.apply(0)
+ buffer.clear()
+ set(buffer)
+ }
+
+ def getConnectionId: String = {
+ return connectionId
+ }
+
+ def getToken: Array[Byte] = {
+ return token
+ }
+
+ /**
+ * Create a BufferMessage that can be sent by the ConnectionManager containing
+ * the security information from this class.
+ * @return BufferMessage
+ */
+ def toBufferMessage: BufferMessage = {
+ val startTime = System.currentTimeMillis
+ val buffers = new ArrayBuffer[ByteBuffer]()
+
+ // 4 bytes for the length of the connectionId
+ // connectionId is of type char so multiply the length by 2 to get the number of bytes
+ // 4 bytes for the length of token
+ // token is a byte buffer so just take the length
+ var buffer = ByteBuffer.allocate(4 + connectionId.length() * 2 + 4 + token.length)
+ buffer.putInt(connectionId.length())
+ connectionId.foreach((x: Char) => buffer.putChar(x))
+ buffer.putInt(token.length)
+
+ if (token.length > 0) {
+ buffer.put(token)
+ }
+ buffer.flip()
+ buffers += buffer
+
+ var message = Message.createBufferMessage(buffers)
+ logDebug("message total size is : " + message.size)
+ message.isSecurityNeg = true
+ return message
+ }
+
+ override def toString: String = {
+ "SecurityMessage [connId= " + connectionId + ", Token = " + token + "]"
+ }
+}
+
+private[spark] object SecurityMessage {
+
+ /**
+ * Convert the given BufferMessage to a SecurityMessage by parsing the contents
+ * of the BufferMessage and populating the SecurityMessage fields.
+ * @param bufferMessage is a BufferMessage that was received
+ * @return new SecurityMessage
+ */
+ def fromBufferMessage(bufferMessage: BufferMessage): SecurityMessage = {
+ val newSecurityMessage = new SecurityMessage()
+ newSecurityMessage.set(bufferMessage)
+ newSecurityMessage
+ }
+
+ /**
+ * Create a SecurityMessage to send from a given saslResponse.
+ * @param response is the response to a challenge from the SaslClient or SaslServer
+ * @param connectionId the client connectionId we are negotiating authentication for
+ * @return a new SecurityMessage
+ */
+ def fromResponse(response : Array[Byte], connectionId : String) : SecurityMessage = {
+ val newSecurityMessage = new SecurityMessage()
+ newSecurityMessage.set(response, connectionId)
+ newSecurityMessage
+ }
+}
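
The class comment above documents the wire layout as id length, id characters, token length, token bytes. The sketch below is a stand-alone round trip of that length-prefixed format; it is not the Spark class itself.

    import java.nio.ByteBuffer

    object SecurityPayloadSketch {
      def encode(connectionId: String, token: Array[Byte]): ByteBuffer = {
        // chars are written as 2 bytes each, hence the * 2
        val buf = ByteBuffer.allocate(4 + connectionId.length * 2 + 4 + token.length)
        buf.putInt(connectionId.length)
        connectionId.foreach(c => buf.putChar(c))
        buf.putInt(token.length)
        buf.put(token)
        buf.flip()
        buf
      }

      def decode(buf: ByteBuffer): (String, Array[Byte]) = {
        val idLength = buf.getInt()
        val id = (1 to idLength).map(_ => buf.getChar()).mkString
        val token = new Array[Byte](buf.getInt())
        buf.get(token)
        (id, token)
      }

      def main(args: Array[String]): Unit = {
        val (id, token) = decode(encode("host_7077_1", Array[Byte](1, 2, 3)))
        println(id + " / " + token.mkString(","))  // host_7077_1 / 1,2,3
      }
    }
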
diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala
index 646f8425d9551..aac2c24a46faa 100644
--- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala
+++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala
@@ -18,8 +18,7 @@
package org.apache.spark.network
import java.nio.ByteBuffer
-
-import org.apache.spark.SparkConf
+import org.apache.spark.{SecurityManager, SparkConf}
private[spark] object SenderTest {
def main(args: Array[String]) {
@@ -32,8 +31,8 @@ private[spark] object SenderTest {
val targetHost = args(0)
val targetPort = args(1).toInt
val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort)
-
- val manager = new ConnectionManager(0, new SparkConf)
+ val conf = new SparkConf
+ val manager = new ConnectionManager(0, conf, new SecurityManager(conf))
println("Started connection manager with id = " + manager.id)
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 699a10c96c227..8561711931047 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap}
+import org.apache.spark.serializer.Serializer
private[spark] sealed trait CoGroupSplitDep extends Serializable
@@ -66,10 +67,10 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
private type CoGroupValue = (Any, Int) // Int is dependency number
private type CoGroupCombiner = Seq[CoGroup]
- private var serializerClass: String = null
+ private var serializer: Serializer = null
- def setSerializer(cls: String): CoGroupedRDD[K] = {
- serializerClass = cls
+ def setSerializer(serializer: Serializer): CoGroupedRDD[K] = {
+ this.serializer = serializer
this
}
@@ -80,7 +81,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
new OneToOneDependency(rdd)
} else {
logDebug("Adding shuffle dependency with " + rdd)
- new ShuffleDependency[Any, Any](rdd, part, serializerClass)
+ new ShuffleDependency[Any, Any](rdd, part, serializer)
}
}
}
@@ -113,18 +114,17 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
// A list of (rdd iterator, dependency number) pairs
val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)]
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
- case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
+ case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
// Read them from the parent
val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]]
rddIterators += ((it, depNum))
- }
- case ShuffleCoGroupSplitDep(shuffleId) => {
+
+ case ShuffleCoGroupSplitDep(shuffleId) =>
// Read map outputs of shuffle
val fetcher = SparkEnv.get.shuffleFetcher
- val ser = SparkEnv.get.serializerManager.get(serializerClass, sparkConf)
+ val ser = Serializer.getSerializer(serializer)
val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser)
rddIterators += ((it, depNum))
- }
}
if (!externalSorting) {
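
With this change callers hand CoGroupedRDD a Serializer instance rather than a class-name string. A hedged usage sketch follows; the local master and the KryoSerializer constructor taking a SparkConf are assumptions about the surrounding Spark build, and the data is illustrative.

    import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
    import org.apache.spark.rdd.CoGroupedRDD
    import org.apache.spark.serializer.KryoSerializer

    object CoGroupSerializerExample {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setMaster("local[2]").setAppName("cogroup-serializer")
        val sc = new SparkContext(conf)
        val a = sc.parallelize(Seq(("k1", 1), ("k2", 2)))
        val b = sc.parallelize(Seq(("k1", "x"), ("k3", "y")))
        // previously: .setSerializer("org.apache.spark.serializer.KryoSerializer")
        val grouped = new CoGroupedRDD[String](Seq(a, b), new HashPartitioner(2))
          .setSerializer(new KryoSerializer(conf))
        grouped.collect().foreach(println)
        sc.stop()
      }
    }
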
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index a374fc4a871b0..100ddb360732a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -18,8 +18,10 @@
package org.apache.spark.rdd
import java.io.EOFException
+import scala.collection.immutable.Map
import org.apache.hadoop.conf.{Configurable, Configuration}
+import org.apache.hadoop.mapred.FileSplit
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.InputSplit
import org.apache.hadoop.mapred.JobConf
@@ -43,6 +45,23 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp
override def hashCode(): Int = 41 * (41 + rddId) + idx
override val index: Int = idx
+
+ /**
+ * Get any environment variables that should be added to the user's environment when running pipes
+ * @return a Map with the environment variables and their corresponding values; it could be empty
+ */
+ def getPipeEnvVars(): Map[String, String] = {
+ val envVars: Map[String, String] = if (inputSplit.value.isInstanceOf[FileSplit]) {
+ val is: FileSplit = inputSplit.value.asInstanceOf[FileSplit]
+ // map_input_file is deprecated in favor of mapreduce_map_input_file but set both
+ // since it's not removed yet
+ Map("map_input_file" -> is.getPath().toString(),
+ "mapreduce_map_input_file" -> is.getPath().toString())
+ } else {
+ Map()
+ }
+ envVars
+ }
}
/**
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index d29a1a9881cd4..447deafff53cd 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -30,23 +30,21 @@ import scala.reflect.ClassTag
import com.clearspring.analytics.stream.cardinality.HyperLogLog
import org.apache.hadoop.conf.{Configurable, Configuration}
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat}
-import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
-import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob}
-import org.apache.hadoop.mapreduce.{RecordWriter => NewRecordWriter}
+import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, Job => NewAPIHadoopJob, RecordWriter => NewRecordWriter, SparkHadoopMapReduceUtil}
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
// SparkHadoopWriter and SparkHadoopMapReduceUtil are actually source files defined in Spark.
import org.apache.hadoop.mapred.SparkHadoopWriter
-import org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark._
import org.apache.spark.Partitioner.defaultPartitioner
import org.apache.spark.SparkContext._
import org.apache.spark.partial.{BoundedDouble, PartialResult}
+import org.apache.spark.serializer.Serializer
import org.apache.spark.util.SerializableHyperLogLog
/**
@@ -76,7 +74,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
mergeCombiners: (C, C) => C,
partitioner: Partitioner,
mapSideCombine: Boolean = true,
- serializerClass: String = null): RDD[(K, C)] = {
+ serializer: Serializer = null): RDD[(K, C)] = {
require(mergeCombiners != null, "mergeCombiners must be defined") // required as of Spark 0.9.0
if (getKeyClass().isArray) {
if (mapSideCombine) {
@@ -96,13 +94,13 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
aggregator.combineValuesByKey(iter, context)
}, preservesPartitioning = true)
val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner)
- .setSerializer(serializerClass)
+ .setSerializer(serializer)
partitioned.mapPartitionsWithContext((context, iter) => {
new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter, context))
}, preservesPartitioning = true)
} else {
// Don't apply map-side combiner.
- val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass)
+ val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializer)
values.mapPartitionsWithContext((context, iter) => {
new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
}, preservesPartitioning = true)
@@ -196,6 +194,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
}
/** Alias for reduceByKeyLocally */
+ @deprecated("Use reduceByKeyLocally", "1.0.0")
def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func)
/** Count the number of elements for each key, and return the result to the master as a Map. */
@@ -425,7 +424,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
* Return the key-value pairs in this RDD to the master as a Map.
*/
def collectAsMap(): Map[K, V] = {
- val data = self.toArray()
+ val data = self.collect()
val map = new mutable.HashMap[K, V]
map.sizeHint(data.length)
data.foreach { case (k, v) => map.put(k, v) }
@@ -604,46 +603,9 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
val job = new NewAPIHadoopJob(conf)
job.setOutputKeyClass(keyClass)
job.setOutputValueClass(valueClass)
- val wrappedConf = new SerializableWritable(job.getConfiguration)
- NewFileOutputFormat.setOutputPath(job, new Path(path))
- val formatter = new SimpleDateFormat("yyyyMMddHHmm")
- val jobtrackerID = formatter.format(new Date())
- val stageId = self.id
- def writeShard(context: TaskContext, iter: Iterator[(K,V)]): Int = {
- // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
- // around by taking a mod. We expect that no task will be attempted 2 billion times.
- val attemptNumber = (context.attemptId % Int.MaxValue).toInt
- /* "reduce task" */
- val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
- attemptNumber)
- val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
- val format = outputFormatClass.newInstance
- format match {
- case c: Configurable => c.setConf(wrappedConf.value)
- case _ => ()
- }
- val committer = format.getOutputCommitter(hadoopContext)
- committer.setupTask(hadoopContext)
- val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]]
- while (iter.hasNext) {
- val (k, v) = iter.next()
- writer.write(k, v)
- }
- writer.close(hadoopContext)
- committer.commitTask(hadoopContext)
- return 1
- }
- val jobFormat = outputFormatClass.newInstance
- /* apparently we need a TaskAttemptID to construct an OutputCommitter;
- * however we're only going to use this local OutputCommitter for
- * setupJob/commitJob, so we just use a dummy "map" task.
- */
- val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0)
- val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId)
- val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
- jobCommitter.setupJob(jobTaskContext)
- val count = self.context.runJob(self, writeShard _).sum
- jobCommitter.commitJob(jobTaskContext)
+ job.setOutputFormatClass(outputFormatClass)
+ job.getConfiguration.set("mapred.output.dir", path)
+ saveAsNewAPIHadoopDataset(job.getConfiguration)
}
/**
@@ -689,6 +651,59 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
saveAsHadoopDataset(conf)
}
+ /**
+ * Output the RDD to any Hadoop-supported storage system with new Hadoop API, using a Hadoop
+ * Configuration object for that storage system. The Conf should set an OutputFormat and any
+ * output paths required (e.g. a table name to write to) in the same way as it would be
+ * configured for a Hadoop MapReduce job.
+ */
+ def saveAsNewAPIHadoopDataset(conf: Configuration) {
+ val job = new NewAPIHadoopJob(conf)
+ val formatter = new SimpleDateFormat("yyyyMMddHHmm")
+ val jobtrackerID = formatter.format(new Date())
+ val stageId = self.id
+ val wrappedConf = new SerializableWritable(job.getConfiguration)
+ val outfmt = job.getOutputFormatClass
+ val jobFormat = outfmt.newInstance
+
+ if (jobFormat.isInstanceOf[NewFileOutputFormat[_, _]]) {
+ // FileOutputFormat ignores the filesystem parameter
+ jobFormat.checkOutputSpecs(job)
+ }
+
+ def writeShard(context: TaskContext, iter: Iterator[(K,V)]): Int = {
+ // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
+ // around by taking a mod. We expect that no task will be attempted 2 billion times.
+ val attemptNumber = (context.attemptId % Int.MaxValue).toInt
+ /* "reduce task" */
+ val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
+ attemptNumber)
+ val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
+ val format = outfmt.newInstance
+ format match {
+ case c: Configurable => c.setConf(wrappedConf.value)
+ case _ => ()
+ }
+ val committer = format.getOutputCommitter(hadoopContext)
+ committer.setupTask(hadoopContext)
+ val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]]
+ while (iter.hasNext) {
+ val (k, v) = iter.next()
+ writer.write(k, v)
+ }
+ writer.close(hadoopContext)
+ committer.commitTask(hadoopContext)
+ return 1
+ }
+
+ val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0)
+ val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId)
+ val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
+ jobCommitter.setupJob(jobTaskContext)
+ self.context.runJob(self, writeShard _)
+ jobCommitter.commitJob(jobTaskContext)
+ }
+
/**
* Output the RDD to any Hadoop-supported storage system, using a Hadoop JobConf object for
* that storage system. The JobConf should set an OutputFormat and any output paths required
@@ -696,10 +711,10 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
* MapReduce job.
*/
def saveAsHadoopDataset(conf: JobConf) {
- val outputFormatClass = conf.getOutputFormat
+ val outputFormatInstance = conf.getOutputFormat
val keyClass = conf.getOutputKeyClass
val valueClass = conf.getOutputValueClass
- if (outputFormatClass == null) {
+ if (outputFormatInstance == null) {
throw new SparkException("Output format class not set")
}
if (keyClass == null) {
@@ -712,6 +727,12 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " +
valueClass.getSimpleName + ")")
+ if (outputFormatInstance.isInstanceOf[FileOutputFormat[_, _]]) {
+ // FileOutputFormat ignores the filesystem parameter
+ val ignoredFs = FileSystem.get(conf)
+ conf.getOutputFormat.checkOutputSpecs(ignoredFs, conf)
+ }
+
val writer = new SparkHadoopWriter(conf)
writer.preSetup()
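
A hedged usage sketch for the new saveAsNewAPIHadoopDataset: the Job's Configuration carries the output format, key/value classes and output path, as the doc comment above describes. The local master, the /tmp output path and the TextOutputFormat choice are assumptions made for illustration.

    import org.apache.hadoop.fs.Path
    import org.apache.hadoop.io.{NullWritable, Text}
    import org.apache.hadoop.mapreduce.Job
    import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.SparkContext._

    object SaveAsNewApiDatasetExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("save-dataset"))
        val pairs = sc.parallelize(1 to 3).map(i => (NullWritable.get(), new Text("line " + i)))

        val job = new Job(sc.hadoopConfiguration)
        job.setOutputKeyClass(classOf[NullWritable])
        job.setOutputValueClass(classOf[Text])
        job.setOutputFormatClass(classOf[TextOutputFormat[NullWritable, Text]])
        FileOutputFormat.setOutputPath(job, new Path("/tmp/save-dataset-out"))  // hypothetical path

        // everything the writer needs now travels inside the Configuration
        pairs.saveAsNewAPIHadoopDataset(job.getConfiguration)
        sc.stop()
      }
    }
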
diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index abd4414e81f5c..4250a9d02f764 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -28,6 +28,7 @@ import scala.reflect.ClassTag
import org.apache.spark.{Partition, SparkEnv, TaskContext}
+
/**
* An RDD that pipes the contents of each parent partition through an external command
* (printing them one per line) and returns the output as a collection of strings.
@@ -59,6 +60,13 @@ class PipedRDD[T: ClassTag](
val currentEnvVars = pb.environment()
envVars.foreach { case (variable, value) => currentEnvVars.put(variable, value) }
+ // for compatibility with Hadoop which sets these env variables
+ // so the user code can access the input filename
+ if (split.isInstanceOf[HadoopPartition]) {
+ val hadoopSplit = split.asInstanceOf[HadoopPartition]
+ currentEnvVars.putAll(hadoopSplit.getPipeEnvVars())
+ }
+
val proc = pb.start()
val env = SparkEnv.get
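A hedged sketch of what the HadoopPartition change enables: commands run through pipe() on a Hadoop-backed RDD can read the input-file environment variables that Hadoop streaming scripts traditionally rely on. The variable name map_input_file below is taken from HadoopPartition.getPipeEnvVars and should be treated as an assumption here.

    // Assumes `sc` is a running SparkContext and the path points at a directory of text files.
    val lines = sc.textFile("hdfs:///data/logs")
    // Each partition's external process can now see which input file it is reading:
    val sourceFiles = lines.pipe(Seq("sh", "-c", "echo $map_input_file")).distinct()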
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 50320f40350cd..ddb901246d360 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -126,14 +126,6 @@ abstract class RDD[T: ClassTag](
this
}
- /** User-defined generator of this RDD*/
- @transient var generator = Utils.getCallSiteInfo.firstUserClass
-
- /** Reset generator*/
- def setGenerator(_generator: String) = {
- generator = _generator
- }
-
/**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. This can only be used to assign a new storage level if the RDD does not
@@ -318,6 +310,7 @@ abstract class RDD[T: ClassTag](
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = {
+ require(fraction >= 0.0, "Invalid fraction value: " + fraction)
if (withReplacement) {
new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed)
} else {
@@ -352,6 +345,10 @@ abstract class RDD[T: ClassTag](
throw new IllegalArgumentException("Negative number of elements requested")
}
+ if (initialCount == 0) {
+ return new Array[T](0)
+ }
+
if (initialCount > Integer.MAX_VALUE - 1) {
maxSelected = Integer.MAX_VALUE - 1
} else {
@@ -370,7 +367,7 @@ abstract class RDD[T: ClassTag](
var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
// If the first sample didn't turn out large enough, keep trying to take samples;
- // this shouldn't happen often because we use a big multiplier for thei initial size
+ // this shouldn't happen often because we use a big multiplier for the initial size
while (samples.length < total) {
samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
}
@@ -543,7 +540,8 @@ abstract class RDD[T: ClassTag](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def mapWith[A: ClassTag, U: ClassTag]
+ @deprecated("use mapPartitionsWithIndex", "1.0.0")
+ def mapWith[A, U: ClassTag]
(constructA: Int => A, preservesPartitioning: Boolean = false)
(f: (T, A) => U): RDD[U] = {
mapPartitionsWithIndex((index, iter) => {
@@ -557,7 +555,8 @@ abstract class RDD[T: ClassTag](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def flatMapWith[A: ClassTag, U: ClassTag]
+ @deprecated("use mapPartitionsWithIndex and flatMap", "1.0.0")
+ def flatMapWith[A, U: ClassTag]
(constructA: Int => A, preservesPartitioning: Boolean = false)
(f: (T, A) => Seq[U]): RDD[U] = {
mapPartitionsWithIndex((index, iter) => {
@@ -571,7 +570,8 @@ abstract class RDD[T: ClassTag](
* This additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def foreachWith[A: ClassTag](constructA: Int => A)(f: (T, A) => Unit) {
+ @deprecated("use mapPartitionsWithIndex and foreach", "1.0.0")
+ def foreachWith[A](constructA: Int => A)(f: (T, A) => Unit) {
mapPartitionsWithIndex { (index, iter) =>
val a = constructA(index)
iter.map(t => {f(t, a); t})
@@ -583,7 +583,8 @@ abstract class RDD[T: ClassTag](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def filterWith[A: ClassTag](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = {
+ @deprecated("use mapPartitionsWithIndex and filter", "1.0.0")
+ def filterWith[A](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = {
mapPartitionsWithIndex((index, iter) => {
val a = constructA(index)
iter.filter(t => p(t, a))
@@ -662,6 +663,7 @@ abstract class RDD[T: ClassTag](
/**
* Return an array that contains all of the elements in this RDD.
*/
+ @deprecated("use collect", "1.0.0")
def toArray(): Array[T] = collect()
/**
@@ -954,6 +956,18 @@ abstract class RDD[T: ClassTag](
*/
def takeOrdered(num: Int)(implicit ord: Ordering[T]): Array[T] = top(num)(ord.reverse)
+ /**
+ * Returns the max of this RDD as defined by the implicit Ordering[T].
+ * @return the maximum element of the RDD
+ */
+ def max()(implicit ord: Ordering[T]): T = this.reduce(ord.max)
+
+ /**
+ * Returns the min of this RDD as defined by the implicit Ordering[T].
+ * @return the minimum element of the RDD
+ */
+ def min()(implicit ord: Ordering[T]): T = this.reduce(ord.min)
+
/**
* Save this RDD as a text file, using string representations of elements.
*/
@@ -1027,8 +1041,9 @@ abstract class RDD[T: ClassTag](
private var storageLevel: StorageLevel = StorageLevel.NONE
- /** Record user function generating this RDD. */
- @transient private[spark] val origin = sc.getCallSite()
+ /** User code that created this RDD (e.g. `textFile`, `parallelize`). */
+ @transient private[spark] val creationSiteInfo = Utils.getCallSiteInfo
+ private[spark] def getCreationSite = Utils.formatCallSiteInfo(creationSiteInfo)
private[spark] def elementClassTag: ClassTag[T] = classTag[T]
@@ -1091,10 +1106,7 @@ abstract class RDD[T: ClassTag](
}
override def toString: String = "%s%s[%d] at %s".format(
- Option(name).map(_ + " ").getOrElse(""),
- getClass.getSimpleName,
- id,
- origin)
+ Option(name).map(_ + " ").getOrElse(""), getClass.getSimpleName, id, getCreationSite)
def toJavaRDD() : JavaRDD[T] = {
new JavaRDD(this)(elementClassTag)
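A brief usage sketch of the new max()/min() helpers added in this file: both reduce over the whole RDD, so the RDD must be non-empty, and custom element types need an Ordering in scope. Temp below is a hypothetical user class.

    val nums = sc.parallelize(Seq(3, 1, 4, 1, 5, 9))
    nums.max()   // 9, via the implicit Ordering[Int]
    nums.min()   // 1

    case class Temp(celsius: Double)
    implicit val byCelsius: Ordering[Temp] = Ordering.by(_.celsius)
    sc.parallelize(Seq(Temp(12.5), Temp(-3.0))).max()   // Temp(12.5)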
diff --git a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
index b50307cfa49b7..4ceea557f569c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
@@ -26,13 +26,13 @@ import cern.jet.random.engine.DRand
import org.apache.spark.{Partition, TaskContext}
-@deprecated("Replaced by PartitionwiseSampledRDDPartition", "1.0")
+@deprecated("Replaced by PartitionwiseSampledRDDPartition", "1.0.0")
private[spark]
class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable {
override val index: Int = prev.index
}
-@deprecated("Replaced by PartitionwiseSampledRDD", "1.0")
+@deprecated("Replaced by PartitionwiseSampledRDD", "1.0.0")
class SampledRDD[T: ClassTag](
prev: RDD[T],
withReplacement: Boolean,
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index 0bbda25a905cd..02660ea6a45c5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -20,6 +20,7 @@ package org.apache.spark.rdd
import scala.reflect.ClassTag
import org.apache.spark.{Dependency, Partition, Partitioner, ShuffleDependency, SparkEnv, TaskContext}
+import org.apache.spark.serializer.Serializer
private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
override val index = idx
@@ -38,15 +39,15 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
part: Partitioner)
extends RDD[P](prev.context, Nil) {
- private var serializerClass: String = null
+ private var serializer: Serializer = null
- def setSerializer(cls: String): ShuffledRDD[K, V, P] = {
- serializerClass = cls
+ def setSerializer(serializer: Serializer): ShuffledRDD[K, V, P] = {
+ this.serializer = serializer
this
}
override def getDependencies: Seq[Dependency[_]] = {
- List(new ShuffleDependency(prev, part, serializerClass))
+ List(new ShuffleDependency(prev, part, serializer))
}
override val partitioner = Some(part)
@@ -57,8 +58,8 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
override def compute(split: Partition, context: TaskContext): Iterator[P] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
- SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context,
- SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf))
+ val ser = Serializer.getSerializer(serializer)
+ SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, ser)
}
override def clearDependencies() {
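A hedged sketch of the API change: setSerializer now takes a Serializer instance instead of a class-name string, which is why JavaSerializer and KryoSerializer become serializable later in this patch and why SerializerManager can be deleted. Assumes a running SparkContext `sc`.

    import org.apache.spark.HashPartitioner
    import org.apache.spark.rdd.ShuffledRDD
    import org.apache.spark.serializer.KryoSerializer

    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))
    val shuffled = new ShuffledRDD[String, Int, (String, Int)](pairs, new HashPartitioner(4))
      .setSerializer(new KryoSerializer(sc.getConf))  // leaving it unset falls back to SparkEnv's default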
diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
index 5fe9f363db453..9a09c05bbc959 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -30,6 +30,7 @@ import org.apache.spark.Partitioner
import org.apache.spark.ShuffleDependency
import org.apache.spark.SparkEnv
import org.apache.spark.TaskContext
+import org.apache.spark.serializer.Serializer
/**
* An optimized version of cogroup for set difference/subtraction.
@@ -53,10 +54,10 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
part: Partitioner)
extends RDD[(K, V)](rdd1.context, Nil) {
- private var serializerClass: String = null
+ private var serializer: Serializer = null
- def setSerializer(cls: String): SubtractedRDD[K, V, W] = {
- serializerClass = cls
+ def setSerializer(serializer: Serializer): SubtractedRDD[K, V, W] = {
+ this.serializer = serializer
this
}
@@ -67,7 +68,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
new OneToOneDependency(rdd)
} else {
logDebug("Adding shuffle dependency with " + rdd)
- new ShuffleDependency(rdd, part, serializerClass)
+ new ShuffleDependency(rdd, part, serializer)
}
}
}
@@ -92,7 +93,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
val partition = p.asInstanceOf[CoGroupPartition]
- val serializer = SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf)
+ val ser = Serializer.getSerializer(serializer)
val map = new JHashMap[K, ArrayBuffer[V]]
def getSeq(k: K): ArrayBuffer[V] = {
val seq = map.get(k)
@@ -105,14 +106,13 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
}
}
def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit) = dep match {
- case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
+ case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op)
- }
- case ShuffleCoGroupSplitDep(shuffleId) => {
+
+ case ShuffleCoGroupSplitDep(shuffleId) =>
val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index,
- context, serializer)
+ context, ser)
iter.foreach(op)
- }
}
// the first dep is rdd1; add all values to the map
integrate(partition.deps(0), t => getSeq(t._1) += t._2)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index dc5b25d845dc2..d83d0341c61ab 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -279,7 +279,7 @@ class DAGScheduler(
} else {
// Kind of ugly: need to register RDDs with the cache and map output tracker here
// since we can't do it in the RDD constructor because # of partitions is unknown
- logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
+ logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")")
mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.size)
}
stage
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index 9d75d7c4ad69a..01cbcc390c6cd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -81,7 +81,7 @@ class JobLogger(val user: String, val logDirName: String)
/**
* Create a log file for one job
* @param jobID ID of the job
- * @exception FileNotFoundException Fail to create log file
+ * @throws FileNotFoundException Fail to create log file
*/
protected def createLogWriter(jobID: Int) {
try {
@@ -213,14 +213,10 @@ class JobLogger(val user: String, val logDirName: String)
* @param indent Indent number before info
*/
protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
+ val cacheStr = if (rdd.getStorageLevel != StorageLevel.NONE) "CACHED" else "NONE"
val rddInfo =
- if (rdd.getStorageLevel != StorageLevel.NONE) {
- "RDD_ID=" + rdd.id + " " + getRddName(rdd) + " CACHED" + " " +
- rdd.origin + " " + rdd.generator
- } else {
- "RDD_ID=" + rdd.id + " " + getRddName(rdd) + " NONE" + " " +
- rdd.origin + " " + rdd.generator
- }
+ s"RDD_ID=$rdd.id ${getRddName(rdd)} $cacheStr " +
+ s"${rdd.getCreationSite} ${rdd.creationSiteInfo.firstUserClass}"
jobLogInfo(jobID, indentString(indent) + rddInfo, false)
rdd.dependencies.foreach {
case shufDep: ShuffleDependency[_, _] =>
@@ -275,7 +271,6 @@ class JobLogger(val user: String, val logDirName: String)
" BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
" BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
" REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
- " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
" REMOTE_BYTES_READ=" + metrics.remoteBytesRead
case None => ""
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
index eefc8c232b564..f1924a4573b21 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler
/**
* A backend interface for scheduling systems that allows plugging in different ones under
- * ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as
+ * TaskSchedulerImpl. We assume a Mesos-like model where the application gets resource offers as
* machines become available and can launch tasks on them.
*/
private[spark] trait SchedulerBackend {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 77789031f464a..2a9edf4a76b97 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -26,6 +26,7 @@ import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.RDDCheckpointData
+import org.apache.spark.serializer.Serializer
import org.apache.spark.storage._
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
@@ -153,7 +154,7 @@ private[spark] class ShuffleMapTask(
try {
// Obtain all the block writers for shuffle blocks.
- val ser = SparkEnv.get.serializerManager.get(dep.serializerClass, SparkEnv.get.conf)
+ val ser = Serializer.getSerializer(dep.serializer)
shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser)
// Write the map output to its associated buckets.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index a78b0186b9eab..5c1fc30e4a557 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -100,7 +100,7 @@ private[spark] class Stage(
id
}
- val name = callSite.getOrElse(rdd.origin)
+ val name = callSite.getOrElse(rdd.getCreationSite)
override def toString = "Stage " + id
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 1cdfed1d7005e..92616c997e20c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -20,7 +20,7 @@ package org.apache.spark.scheduler
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
/**
- * Low-level task scheduler interface, currently implemented exclusively by the ClusterScheduler.
+ * Low-level task scheduler interface, currently implemented exclusively by TaskSchedulerImpl.
* This interface allows plugging in different task schedulers. Each TaskScheduler schedules tasks
* for a single SparkContext. These schedulers get sets of tasks submitted to them from the
* DAGScheduler for each stage, and are responsible for sending the tasks to the cluster, running
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 8df37c247d0d4..23b06612fd7ab 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -25,6 +25,7 @@ import scala.concurrent.duration._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
+import scala.util.Random
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
@@ -207,9 +208,11 @@ private[spark] class TaskSchedulerImpl(
}
}
- // Build a list of tasks to assign to each worker
- val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores))
- val availableCpus = offers.map(o => o.cores).toArray
+ // Randomly shuffle offers to avoid always placing tasks on the same set of workers.
+ val shuffledOffers = Random.shuffle(offers)
+ // Build a list of tasks to assign to each worker.
+ val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores))
+ val availableCpus = shuffledOffers.map(o => o.cores).toArray
val sortedTaskSets = rootPool.getSortedTaskSetQueue()
for (taskSet <- sortedTaskSets) {
logDebug("parentName: %s, name: %s, runningTasks: %s".format(
@@ -222,9 +225,9 @@ private[spark] class TaskSchedulerImpl(
for (taskSet <- sortedTaskSets; maxLocality <- TaskLocality.values) {
do {
launchedTask = false
- for (i <- 0 until offers.size) {
- val execId = offers(i).executorId
- val host = offers(i).host
+ for (i <- 0 until shuffledOffers.size) {
+ val execId = shuffledOffers(i).executorId
+ val host = shuffledOffers(i).host
for (task <- taskSet.resourceOffer(execId, host, availableCpus(i), maxLocality)) {
tasks(i) += task
val tid = task.taskId
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 1a4b7e599c01e..5ea4557bbf56a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -26,13 +26,14 @@ import scala.collection.mutable.HashSet
import scala.math.max
import scala.math.min
-import org.apache.spark.{ExceptionFailure, ExecutorLostFailure, FetchFailed, Logging, Resubmitted, SparkEnv, Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState}
+import org.apache.spark.{ExceptionFailure, ExecutorLostFailure, FetchFailed, Logging, Resubmitted,
+ SparkEnv, Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.util.{Clock, SystemClock}
/**
- * Schedules the tasks within a single TaskSet in the ClusterScheduler. This class keeps track of
+ * Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of
* each task, retries tasks if they fail (up to a limited number of times), and
* handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces
* to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node,
@@ -41,7 +42,7 @@ import org.apache.spark.util.{Clock, SystemClock}
* THREADING: This class is designed to only be called from code with a lock on the
* TaskScheduler (e.g. its event handlers). It should not be called from other threads.
*
- * @param sched the ClusterScheduler associated with the TaskSetManager
+ * @param sched the TaskSchedulerImpl associated with the TaskSetManager
* @param taskSet the TaskSet to manage scheduling for
* @param maxTaskFailures if any particular task fails more than this number of times, the entire
* task set will be aborted
diff --git a/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala
index ba6bab3f91a65..810b36cddf835 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala
@@ -21,4 +21,4 @@ package org.apache.spark.scheduler
* Represents free resources available on an executor.
*/
private[spark]
-class WorkerOffer(val executorId: String, val host: String, val cores: Int)
+case class WorkerOffer(executorId: String, host: String, cores: Int)
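Making WorkerOffer a case class gives scheduler code and its tests structural construction, equality, and pattern matching for free; a small hedged illustration (WorkerOffer is private[spark], so this only applies inside Spark's own packages):

    import org.apache.spark.scheduler.WorkerOffer

    val offers = IndexedSeq(WorkerOffer("exec-1", "host-a", 4), WorkerOffer("exec-2", "host-b", 2))
    val totalCores = offers.map { case WorkerOffer(_, _, cores) => cores }.sum   // 6
    val shuffled = scala.util.Random.shuffle(offers)   // as TaskSchedulerImpl now does with its offers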
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 379e02eb9a437..fad03731572e7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -54,6 +54,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
private val executorAddress = new HashMap[String, Address]
private val executorHost = new HashMap[String, String]
private val freeCores = new HashMap[String, Int]
+ private val totalCores = new HashMap[String, Int]
private val addressToExecutorId = new HashMap[Address, String]
override def preStart() {
@@ -76,6 +77,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
sender ! RegisteredExecutor(sparkProperties)
executorActor(executorId) = sender
executorHost(executorId) = Utils.parseHostPort(hostPort)._1
+ totalCores(executorId) = cores
freeCores(executorId) = cores
executorAddress(executorId) = sender.path.address
addressToExecutorId(sender.path.address) = executorId
@@ -147,10 +149,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
def removeExecutor(executorId: String, reason: String) {
if (executorActor.contains(executorId)) {
logInfo("Executor " + executorId + " disconnected, so removing it")
- val numCores = freeCores(executorId)
- addressToExecutorId -= executorAddress(executorId)
+ val numCores = totalCores(executorId)
executorActor -= executorId
executorHost -= executorId
+ addressToExecutorId -= executorAddress(executorId)
+ executorAddress -= executorId
+ totalCores -= executorId
freeCores -= executorId
totalCoreCount.addAndGet(-numCores)
scheduler.executorLost(executorId, SlaveLost(reason))
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index c576beb0c0d38..bcf0ce19a54cd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -203,7 +203,7 @@ private[spark] class MesosSchedulerBackend(
getResource(offer.getResourcesList, "cpus").toInt)
}
- // Call into the ClusterScheduler
+ // Call into the TaskSchedulerImpl
val taskLists = scheduler.resourceOffers(offerableWorkers)
// Build a list of Mesos tasks for each slave
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
index 50f7e79e97dd8..16e2f5cf3076d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -35,7 +35,7 @@ private case class KillTask(taskId: Long)
/**
* Calls to LocalBackend are all serialized through LocalActor. Using an actor makes the calls on
* LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend
- * and the ClusterScheduler.
+ * and the TaskSchedulerImpl.
*/
private[spark] class LocalActor(
scheduler: TaskSchedulerImpl,
@@ -76,7 +76,7 @@ private[spark] class LocalActor(
/**
* LocalBackend is used when running a local version of Spark where the executor, backend, and
- * master all run in the same JVM. It sits behind a ClusterScheduler and handles launching tasks
+ * master all run in the same JVM. It sits behind a TaskSchedulerImpl and handles launching tasks
* on a single Executor (created by the LocalBackend) running locally.
*/
private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: Int)
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index 33c1705ad7c58..18a68b05fa853 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -23,16 +23,34 @@ import java.nio.ByteBuffer
import org.apache.spark.SparkConf
import org.apache.spark.util.ByteBufferInputStream
-private[spark] class JavaSerializationStream(out: OutputStream) extends SerializationStream {
- val objOut = new ObjectOutputStream(out)
- def writeObject[T](t: T): SerializationStream = { objOut.writeObject(t); this }
+private[spark] class JavaSerializationStream(out: OutputStream, counterReset: Int)
+ extends SerializationStream {
+ private val objOut = new ObjectOutputStream(out)
+ private var counter = 0
+
+ /**
+ * Calling reset to avoid memory leak:
+ * http://stackoverflow.com/questions/1281549/memory-leak-traps-in-the-java-standard-api
+ * But only call it every `counterReset`-th time (10,000 by default) to avoid bloated
+ * serialization streams, since each reset forces object class descriptions to be re-written.
+ */
+ def writeObject[T](t: T): SerializationStream = {
+ objOut.writeObject(t)
+ if (counterReset > 0 && counter >= counterReset) {
+ objOut.reset()
+ counter = 0
+ } else {
+ counter += 1
+ }
+ this
+ }
def flush() { objOut.flush() }
def close() { objOut.close() }
}
private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader)
extends DeserializationStream {
- val objIn = new ObjectInputStream(in) {
+ private val objIn = new ObjectInputStream(in) {
override def resolveClass(desc: ObjectStreamClass) =
Class.forName(desc.getName, false, loader)
}
@@ -41,7 +59,7 @@ extends DeserializationStream {
def close() { objIn.close() }
}
-private[spark] class JavaSerializerInstance extends SerializerInstance {
+private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance {
def serialize[T](t: T): ByteBuffer = {
val bos = new ByteArrayOutputStream()
val out = serializeStream(bos)
@@ -63,7 +81,7 @@ private[spark] class JavaSerializerInstance extends SerializerInstance {
}
def serializeStream(s: OutputStream): SerializationStream = {
- new JavaSerializationStream(s)
+ new JavaSerializationStream(s, counterReset)
}
def deserializeStream(s: InputStream): DeserializationStream = {
@@ -78,6 +96,16 @@ private[spark] class JavaSerializerInstance extends SerializerInstance {
/**
* A Spark serializer that uses Java's built-in serialization.
*/
-class JavaSerializer(conf: SparkConf) extends Serializer {
- def newInstance(): SerializerInstance = new JavaSerializerInstance
+class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable {
+ private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 10000)
+
+ def newInstance(): SerializerInstance = new JavaSerializerInstance(counterReset)
+
+ override def writeExternal(out: ObjectOutput) {
+ out.writeInt(counterReset)
+ }
+
+ override def readExternal(in: ObjectInput) {
+ counterReset = in.readInt()
+ }
}
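A hedged sketch of the new knob: spark.serializer.objectStreamReset bounds how many objects are written between ObjectOutputStream.reset() calls (10,000 by default). Without periodic resets the stream keeps a reference to every object it has written, which leaks memory on long-running writes; resetting too often re-emits class descriptors and bloats the output. The master URL and app name below are illustrative.

    import org.apache.spark.{SparkConf, SparkContext}

    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("objectStreamReset-example")
      .set("spark.serializer.objectStreamReset", "1000")   // reset roughly every 1000 objects
    val sc = new SparkContext(conf)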
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 920490f9d0d61..6b6d814c1fe92 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -34,10 +34,14 @@ import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock}
/**
* A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]].
*/
-class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serializer with Logging {
- private val bufferSize = {
- conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024
- }
+class KryoSerializer(conf: SparkConf)
+ extends org.apache.spark.serializer.Serializer
+ with Logging
+ with Serializable {
+
+ private val bufferSize = conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024
+ private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true)
+ private val registrator = conf.getOption("spark.kryo.registrator")
def newKryoOutput() = new KryoOutput(bufferSize)
@@ -48,7 +52,7 @@ class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serial
// Allow disabling Kryo reference tracking if user knows their object graphs don't have loops.
// Do this before we invoke the user registrator so the user registrator can override this.
- kryo.setReferences(conf.getBoolean("spark.kryo.referenceTracking", true))
+ kryo.setReferences(referenceTracking)
for (cls <- KryoSerializer.toRegister) kryo.register(cls)
@@ -58,7 +62,7 @@ class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serial
// Allow the user to register their own classes by setting spark.kryo.registrator
try {
- for (regCls <- conf.getOption("spark.kryo.registrator")) {
+ for (regCls <- registrator) {
logDebug("Running user registrator: " + regCls)
val reg = Class.forName(regCls, true, classLoader).newInstance()
.asInstanceOf[KryoRegistrator]
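A hedged sketch of the user-facing configuration, which this refactor does not change; the point of capturing the registrator and referenceTracking settings in fields is that a KryoSerializer instance is now itself Serializable and can be attached directly to an RDD (e.g. via ShuffledRDD.setSerializer). MyRecord and MyRegistrator are hypothetical user classes.

    import com.esotericsoftware.kryo.Kryo
    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.KryoRegistrator

    case class MyRecord(id: Int, name: String)

    class MyRegistrator extends KryoRegistrator {
      override def registerClasses(kryo: Kryo) {
        kryo.register(classOf[MyRecord])
      }
    }

    val conf = new SparkConf()
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .set("spark.kryo.registrator", classOf[MyRegistrator].getName)
      .set("spark.kryo.referenceTracking", "false")   // only safe if object graphs have no cycles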
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index 16677ab54be04..099143494b851 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -23,21 +23,31 @@ import java.nio.ByteBuffer
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
import org.apache.spark.util.{ByteBufferInputStream, NextIterator}
+import org.apache.spark.SparkEnv
/**
* A serializer. Because some serialization libraries are not thread safe, this class is used to
* create [[org.apache.spark.serializer.SerializerInstance]] objects that do the actual
* serialization and are guaranteed to only be called from one thread at a time.
*
- * Implementations of this trait should have a zero-arg constructor or a constructor that accepts a
- * [[org.apache.spark.SparkConf]] as parameter. If both constructors are defined, the latter takes
- * precedence.
+ * Implementations of this trait should implement:
+ * 1. a zero-arg constructor or a constructor that accepts a [[org.apache.spark.SparkConf]]
+ * as parameter. If both constructors are defined, the latter takes precedence.
+ *
+ * 2. Java serialization interface.
*/
trait Serializer {
def newInstance(): SerializerInstance
}
+object Serializer {
+ def getSerializer(serializer: Serializer): Serializer = {
+ if (serializer == null) SparkEnv.get.serializer else serializer
+ }
+}
+
+
/**
* An instance of a serializer, for use by one thread at a time.
*/
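A hedged sketch of the new helper, assuming an initialized SparkEnv and a SparkConf named conf: callers that used to resolve serializer class names through SerializerManager (removed below) now simply fall back to the environment's default serializer when none was set explicitly.

    import org.apache.spark.SparkEnv
    import org.apache.spark.serializer.{KryoSerializer, Serializer}

    val explicit: Serializer = new KryoSerializer(conf)
    Serializer.getSerializer(explicit)   // returns `explicit` unchanged
    Serializer.getSerializer(null)       // returns SparkEnv.get.serializer, the configured default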
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
deleted file mode 100644
index 65ac0155f45e7..0000000000000
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.serializer
-
-import java.util.concurrent.ConcurrentHashMap
-
-import org.apache.spark.SparkConf
-
-/**
- * A service that returns a serializer object given the serializer's class name. If a previous
- * instance of the serializer object has been created, the get method returns that instead of
- * creating a new one.
- */
-private[spark] class SerializerManager {
- // TODO: Consider moving this into SparkConf itself to remove the global singleton.
-
- private val serializers = new ConcurrentHashMap[String, Serializer]
- private var _default: Serializer = _
-
- def default = _default
-
- def setDefault(clsName: String, conf: SparkConf): Serializer = {
- _default = get(clsName, conf)
- _default
- }
-
- def get(clsName: String, conf: SparkConf): Serializer = {
- if (clsName == null) {
- default
- } else {
- var serializer = serializers.get(clsName)
- if (serializer != null) {
- // If the serializer has been created previously, reuse that.
- serializer
- } else this.synchronized {
- // Otherwise, create a new one. But make sure no other thread has attempted
- // to create another new one at the same time.
- serializer = serializers.get(clsName)
- if (serializer == null) {
- val clsLoader = Thread.currentThread.getContextClassLoader
- val cls = Class.forName(clsName, true, clsLoader)
-
- // First try with the constructor that takes SparkConf. If we can't find one,
- // use a no-arg constructor instead.
- try {
- val constructor = cls.getConstructor(classOf[SparkConf])
- serializer = constructor.newInstance(conf).asInstanceOf[Serializer]
- } catch {
- case _: NoSuchMethodException =>
- val constructor = cls.getConstructor()
- serializer = constructor.newInstance().asInstanceOf[Serializer]
- }
-
- serializers.put(clsName, serializer)
- }
- serializer
- }
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
index 925022e7fe6fb..bcfc39146a61e 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
@@ -44,9 +44,13 @@ import org.apache.spark.util.Utils
*/
private[storage]
-trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])]
- with Logging with BlockFetchTracker {
+trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging {
def initialize()
+ def totalBlocks: Int
+ def numLocalBlocks: Int
+ def numRemoteBlocks: Int
+ def fetchWaitTime: Long
+ def remoteBytesRead: Long
}
@@ -74,7 +78,6 @@ object BlockFetcherIterator {
import blockManager._
private var _remoteBytesRead = 0L
- private var _remoteFetchTime = 0L
private var _fetchWaitTime = 0L
if (blocksByAddress == null) {
@@ -120,7 +123,6 @@ object BlockFetcherIterator {
future.onSuccess {
case Some(message) => {
val fetchDone = System.currentTimeMillis()
- _remoteFetchTime += fetchDone - fetchStart
val bufferMessage = message.asInstanceOf[BufferMessage]
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
for (blockMessage <- blockMessageArray) {
@@ -233,7 +235,15 @@ object BlockFetcherIterator {
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
}
- //an iterator that will read fetched blocks off the queue as they arrive.
+ override def totalBlocks: Int = numLocal + numRemote
+ override def numLocalBlocks: Int = numLocal
+ override def numRemoteBlocks: Int = numRemote
+ override def fetchWaitTime: Long = _fetchWaitTime
+ override def remoteBytesRead: Long = _remoteBytesRead
+
+
+ // Implementing the Iterator methods with an iterator that reads fetched blocks off the queue
+ // as they arrive.
@volatile protected var resultsGotten = 0
override def hasNext: Boolean = resultsGotten < _numBlocksToFetch
@@ -251,14 +261,6 @@ object BlockFetcherIterator {
}
(result.blockId, if (result.failed) None else Some(result.deserialize()))
}
-
- // Implementing BlockFetchTracker trait.
- override def totalBlocks: Int = numLocal + numRemote
- override def numLocalBlocks: Int = numLocal
- override def numRemoteBlocks: Int = numRemote
- override def remoteFetchTime: Long = _remoteFetchTime
- override def fetchWaitTime: Long = _fetchWaitTime
- override def remoteBytesRead: Long = _remoteBytesRead
}
// End of BasicBlockFetcherIterator
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index a734ddc1ef702..1bf3f4db32ea7 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -29,19 +29,26 @@ import akka.actor.{ActorSystem, Cancellable, Props}
import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream}
import sun.nio.ch.DirectBuffer
-import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
+import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException, SecurityManager}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
import org.apache.spark.serializer.Serializer
import org.apache.spark.util._
+sealed trait Values
+
+case class ByteBufferValues(buffer: ByteBuffer) extends Values
+case class IteratorValues(iterator: Iterator[Any]) extends Values
+case class ArrayBufferValues(buffer: ArrayBuffer[Any]) extends Values
+
private[spark] class BlockManager(
executorId: String,
actorSystem: ActorSystem,
val master: BlockManagerMaster,
val defaultSerializer: Serializer,
maxMemory: Long,
- val conf: SparkConf)
+ val conf: SparkConf,
+ securityManager: SecurityManager)
extends Logging {
val shuffleBlockManager = new ShuffleBlockManager(this)
@@ -60,7 +67,7 @@ private[spark] class BlockManager(
if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0
}
- val connectionManager = new ConnectionManager(0, conf)
+ val connectionManager = new ConnectionManager(0, conf, securityManager)
implicit val futureExecContext = connectionManager.futureExecContext
val blockManagerId = BlockManagerId(
@@ -116,8 +123,9 @@ private[spark] class BlockManager(
* Construct a BlockManager with a memory limit set based on system properties.
*/
def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster,
- serializer: Serializer, conf: SparkConf) = {
- this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), conf)
+ serializer: Serializer, conf: SparkConf, securityManager: SecurityManager) = {
+ this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), conf,
+ securityManager)
}
/**
@@ -455,9 +463,7 @@ private[spark] class BlockManager(
def put(blockId: BlockId, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
: Long = {
- val elements = new ArrayBuffer[Any]
- elements ++= values
- put(blockId, elements, level, tellMaster)
+ doPut(blockId, IteratorValues(values), level, tellMaster)
}
/**
@@ -479,7 +485,7 @@ private[spark] class BlockManager(
def put(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel,
tellMaster: Boolean = true) : Long = {
require(values != null, "Values is null")
- doPut(blockId, Left(values), level, tellMaster)
+ doPut(blockId, ArrayBufferValues(values), level, tellMaster)
}
/**
@@ -488,10 +494,11 @@ private[spark] class BlockManager(
def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel,
tellMaster: Boolean = true) {
require(bytes != null, "Bytes is null")
- doPut(blockId, Right(bytes), level, tellMaster)
+ doPut(blockId, ByteBufferValues(bytes), level, tellMaster)
}
- private def doPut(blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer],
+ private def doPut(blockId: BlockId,
+ data: Values,
level: StorageLevel, tellMaster: Boolean = true): Long = {
require(blockId != null, "BlockId is null")
require(level != null && level.isValid, "StorageLevel is null or invalid")
@@ -534,8 +541,9 @@ private[spark] class BlockManager(
// If we're storing bytes, then initiate the replication before storing them locally.
// This is faster as data is already serialized and ready to send.
- val replicationFuture = if (data.isRight && level.replication > 1) {
- val bufferView = data.right.get.duplicate() // Doesn't copy the bytes, just creates a wrapper
+ val replicationFuture = if (data.isInstanceOf[ByteBufferValues] && level.replication > 1) {
+ // Duplicate doesn't copy the bytes, just creates a wrapper
+ val bufferView = data.asInstanceOf[ByteBufferValues].buffer.duplicate()
Future {
replicate(blockId, bufferView, level)
}
@@ -549,34 +557,43 @@ private[spark] class BlockManager(
var marked = false
try {
- data match {
- case Left(values) => {
- if (level.useMemory) {
- // Save it just to memory first, even if it also has useDisk set to true; we will
- // drop it to disk later if the memory store can't hold it.
- val res = memoryStore.putValues(blockId, values, level, true)
- size = res.size
- res.data match {
- case Right(newBytes) => bytesAfterPut = newBytes
- case Left(newIterator) => valuesAfterPut = newIterator
- }
- } else {
- // Save directly to disk.
- // Don't get back the bytes unless we replicate them.
- val askForBytes = level.replication > 1
- val res = diskStore.putValues(blockId, values, level, askForBytes)
- size = res.size
- res.data match {
- case Right(newBytes) => bytesAfterPut = newBytes
- case _ =>
- }
+ if (level.useMemory) {
+ // Save it just to memory first, even if it also has useDisk set to true; we will
+ // drop it to disk later if the memory store can't hold it.
+ val res = data match {
+ case IteratorValues(iterator) =>
+ memoryStore.putValues(blockId, iterator, level, true)
+ case ArrayBufferValues(array) =>
+ memoryStore.putValues(blockId, array, level, true)
+ case ByteBufferValues(bytes) => {
+ bytes.rewind()
+ memoryStore.putBytes(blockId, bytes, level)
+ }
+ }
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
+ case Left(newIterator) => valuesAfterPut = newIterator
+ }
+ } else {
+ // Save directly to disk.
+ // Don't get back the bytes unless we replicate them.
+ val askForBytes = level.replication > 1
+
+ val res = data match {
+ case IteratorValues(iterator) =>
+ diskStore.putValues(blockId, iterator, level, askForBytes)
+ case ArrayBufferValues(array) =>
+ diskStore.putValues(blockId, array, level, askForBytes)
+ case ByteBufferValues(bytes) => {
+ bytes.rewind()
+ diskStore.putBytes(blockId, bytes, level)
}
}
- case Right(bytes) => {
- bytes.rewind()
- // Store it only in memory at first, even if useDisk is also set to true
- (if (level.useMemory) memoryStore else diskStore).putBytes(blockId, bytes, level)
- size = bytes.limit
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
+ case _ =>
}
}
@@ -605,8 +622,8 @@ private[spark] class BlockManager(
// values and need to serialize and replicate them now:
if (level.replication > 1) {
data match {
- case Right(bytes) => Await.ready(replicationFuture, Duration.Inf)
- case Left(values) => {
+ case ByteBufferValues(bytes) => Await.ready(replicationFuture, Duration.Inf)
+ case _ => {
val remoteStartTime = System.currentTimeMillis
// Serialize the block if not already done
if (bytesAfterPut == null) {
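A hedged sketch of the dispatch shape doPut now uses: a small sealed hierarchy instead of Either, which lets an Iterator flow straight to the memory or disk store without first being copied into an ArrayBuffer (what the old put(Iterator) overload did). The describe helper below is purely illustrative.

    import org.apache.spark.storage.{ArrayBufferValues, ByteBufferValues, IteratorValues, Values}

    def describe(data: Values): String = data match {
      case IteratorValues(it)      => "values produced lazily by an iterator"
      case ArrayBufferValues(buf)  => s"${buf.size} values already materialized in memory"
      case ByteBufferValues(bytes) => s"${bytes.remaining()} bytes of pre-serialized data"
    }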
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
index b047644b88f48..9a9be047c7245 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
@@ -28,7 +28,7 @@ import org.apache.spark.Logging
*/
private[spark]
abstract class BlockStore(val blockManager: BlockManager) extends Logging {
- def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel)
+ def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) : PutResult
/**
* Put in a block and, possibly, also return its content as either bytes or another Iterator.
@@ -37,6 +37,9 @@ abstract class BlockStore(val blockManager: BlockManager) extends Logging {
* @return a PutResult that contains the size of the data, as well as the values put if
* returnValues is true (if not, the result's data field can be null)
*/
+ def putValues(blockId: BlockId, values: Iterator[Any], level: StorageLevel,
+ returnValues: Boolean) : PutResult
+
def putValues(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel,
returnValues: Boolean) : PutResult
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index d1f07ddb24bb2..36ee4bcc41c66 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -37,7 +37,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage
diskManager.getBlockLocation(blockId).length
}
- override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) {
+ override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) : PutResult = {
// So that we do not modify the input offsets !
// duplicate does not copy buffer, so inexpensive
val bytes = _bytes.duplicate()
@@ -52,6 +52,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage
val finishTime = System.currentTimeMillis
logDebug("Block %s stored as %s file on disk in %d ms".format(
file.getName, Utils.bytesToString(bytes.limit), (finishTime - startTime)))
+ return PutResult(bytes.limit(), Right(bytes.duplicate()))
}
override def putValues(
@@ -59,13 +60,22 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage
values: ArrayBuffer[Any],
level: StorageLevel,
returnValues: Boolean)
+ : PutResult = {
+ return putValues(blockId, values.toIterator, level, returnValues)
+ }
+
+ override def putValues(
+ blockId: BlockId,
+ values: Iterator[Any],
+ level: StorageLevel,
+ returnValues: Boolean)
: PutResult = {
logDebug("Attempting to write values for block " + blockId)
val startTime = System.currentTimeMillis
val file = diskManager.getFile(blockId)
val outputStream = new FileOutputStream(file)
- blockManager.dataSerializeStream(blockId, outputStream, values.iterator)
+ blockManager.dataSerializeStream(blockId, outputStream, values)
val length = file.length
val timeTaken = System.currentTimeMillis - startTime
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
index 18141756518c5..38836d44b04e8 100644
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -49,7 +49,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
- override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) {
+ override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) : PutResult = {
// Work on a duplicate - since the original input might be used elsewhere.
val bytes = _bytes.duplicate()
bytes.rewind()
@@ -59,8 +59,10 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
elements ++= values
val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef])
tryToPut(blockId, elements, sizeEstimate, true)
+ PutResult(sizeEstimate, Left(values.toIterator))
} else {
tryToPut(blockId, bytes, bytes.limit, false)
+ PutResult(bytes.limit(), Right(bytes.duplicate()))
}
}
@@ -69,14 +71,33 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
values: ArrayBuffer[Any],
level: StorageLevel,
returnValues: Boolean)
- : PutResult = {
-
+ : PutResult = {
if (level.deserialized) {
val sizeEstimate = SizeEstimator.estimate(values.asInstanceOf[AnyRef])
tryToPut(blockId, values, sizeEstimate, true)
- PutResult(sizeEstimate, Left(values.iterator))
+ PutResult(sizeEstimate, Left(values.toIterator))
+ } else {
+ val bytes = blockManager.dataSerialize(blockId, values.toIterator)
+ tryToPut(blockId, bytes, bytes.limit, false)
+ PutResult(bytes.limit(), Right(bytes.duplicate()))
+ }
+ }
+
+ override def putValues(
+ blockId: BlockId,
+ values: Iterator[Any],
+ level: StorageLevel,
+ returnValues: Boolean)
+ : PutResult = {
+
+ if (level.deserialized) {
+ val valueEntries = new ArrayBuffer[Any]()
+ valueEntries ++= values
+ val sizeEstimate = SizeEstimator.estimate(valueEntries.asInstanceOf[AnyRef])
+ tryToPut(blockId, valueEntries, sizeEstimate, true)
+ PutResult(sizeEstimate, Left(valueEntries.toIterator))
} else {
- val bytes = blockManager.dataSerialize(blockId, values.iterator)
+ val bytes = blockManager.dataSerialize(blockId, values)
tryToPut(blockId, bytes, bytes.limit, false)
PutResult(bytes.limit(), Right(bytes.duplicate()))
}
@@ -215,13 +236,10 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) {
val pair = iterator.next()
val blockId = pair.getKey
- if (rddToAdd.isDefined && rddToAdd == getRddId(blockId)) {
- logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " +
- "block from the same RDD")
- return false
+ if (rddToAdd.isEmpty || rddToAdd != getRddId(blockId)) {
+ selectedBlocks += blockId
+ selectedMemory += pair.getValue.size
}
- selectedBlocks += blockId
- selectedMemory += pair.getValue.size
}
}
@@ -243,6 +261,8 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
return true
} else {
+ logInfo(s"Will not store $blockIdToAdd as it would require dropping another block " +
+ "from the same RDD")
return false
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
index 1d81d006c0b29..36f2a0fd02724 100644
--- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
@@ -24,6 +24,7 @@ import util.Random
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.{SecurityManager, SparkConf}
/**
* This class tests the BlockManager and MemoryStore for thread safety and
@@ -98,7 +99,8 @@ private[spark] object ThreadingTest {
val blockManagerMaster = new BlockManagerMaster(
actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf))), conf)
val blockManager = new BlockManager(
- "", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf)
+ "", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf,
+ new SecurityManager(conf))
val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue))
producers.foreach(_.start)
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index 1f048a84cdfb6..e0555ca7ac02f 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -18,18 +18,23 @@
package org.apache.spark.ui
import java.net.InetSocketAddress
-import javax.servlet.http.{HttpServletResponse, HttpServletRequest}
+import java.net.URL
+import javax.servlet.http.{HttpServlet, HttpServletResponse, HttpServletRequest}
import scala.annotation.tailrec
import scala.util.{Failure, Success, Try}
import scala.xml.Node
-import net.liftweb.json.{JValue, pretty, render}
-import org.eclipse.jetty.server.{Handler, Request, Server}
-import org.eclipse.jetty.server.handler.{AbstractHandler, ContextHandler, HandlerList, ResourceHandler}
+import org.json4s.JValue
+import org.json4s.jackson.JsonMethods.{pretty, render}
+
+import org.eclipse.jetty.server.{DispatcherType, Server}
+import org.eclipse.jetty.server.handler.HandlerList
+import org.eclipse.jetty.servlet.{DefaultServlet, FilterHolder, ServletContextHandler, ServletHolder}
import org.eclipse.jetty.util.thread.QueuedThreadPool
-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
+
/** Utilities for launching a web server using Jetty's HTTP Server class */
private[spark] object JettyUtils extends Logging {
@@ -38,57 +43,107 @@ private[spark] object JettyUtils extends Logging {
type Responder[T] = HttpServletRequest => T
- // Conversions from various types of Responder's to jetty Handlers
- implicit def jsonResponderToHandler(responder: Responder[JValue]): Handler =
- createHandler(responder, "text/json", (in: JValue) => pretty(render(in)))
+ class ServletParams[T <% AnyRef](val responder: Responder[T],
+ val contentType: String,
+ val extractFn: T => String = (in: Any) => in.toString) {}
+
+ // Conversions from various types of Responder's to appropriate servlet parameters
+ implicit def jsonResponderToServlet(responder: Responder[JValue]): ServletParams[JValue] =
+ new ServletParams(responder, "text/json", (in: JValue) => pretty(render(in)))
- implicit def htmlResponderToHandler(responder: Responder[Seq[Node]]): Handler =
- createHandler(responder, "text/html", (in: Seq[Node]) => "" + in.toString)
+ implicit def htmlResponderToServlet(responder: Responder[Seq[Node]]): ServletParams[Seq[Node]] =
+ new ServletParams(responder, "text/html", (in: Seq[Node]) => "" + in.toString)
- implicit def textResponderToHandler(responder: Responder[String]): Handler =
- createHandler(responder, "text/plain")
+ implicit def textResponderToServlet(responder: Responder[String]): ServletParams[String] =
+ new ServletParams(responder, "text/plain")
- def createHandler[T <% AnyRef](responder: Responder[T], contentType: String,
- extractFn: T => String = (in: Any) => in.toString): Handler = {
- new AbstractHandler {
- def handle(target: String,
- baseRequest: Request,
- request: HttpServletRequest,
+ def createServlet[T <% AnyRef](servletParams: ServletParams[T],
+ securityMgr: SecurityManager): HttpServlet = {
+ new HttpServlet {
+ override def doGet(request: HttpServletRequest,
response: HttpServletResponse) {
- response.setContentType("%s;charset=utf-8".format(contentType))
- response.setStatus(HttpServletResponse.SC_OK)
- baseRequest.setHandled(true)
- val result = responder(request)
- response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
- response.getWriter().println(extractFn(result))
+ if (securityMgr.checkUIViewPermissions(request.getRemoteUser())) {
+ response.setContentType("%s;charset=utf-8".format(servletParams.contentType))
+ response.setStatus(HttpServletResponse.SC_OK)
+ val result = servletParams.responder(request)
+ response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
+ response.getWriter().println(servletParams.extractFn(result))
+ } else {
+ response.setStatus(HttpServletResponse.SC_UNAUTHORIZED)
+ response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
+ response.sendError(HttpServletResponse.SC_UNAUTHORIZED,
+ "User is not authorized to access this page.");
+ }
}
}
}
+ def createServletHandler(path: String, servlet: HttpServlet): ServletContextHandler = {
+ val contextHandler = new ServletContextHandler()
+ val holder = new ServletHolder(servlet)
+ contextHandler.setContextPath(path)
+ contextHandler.addServlet(holder, "/")
+ contextHandler
+ }
+
/** Creates a handler that always redirects the user to a given path */
- def createRedirectHandler(newPath: String): Handler = {
- new AbstractHandler {
- def handle(target: String,
- baseRequest: Request,
- request: HttpServletRequest,
+ def createRedirectHandler(newPath: String, path: String): ServletContextHandler = {
+ val servlet = new HttpServlet {
+ override def doGet(request: HttpServletRequest,
response: HttpServletResponse) {
- response.setStatus(302)
- response.setHeader("Location", baseRequest.getRootURL + newPath)
- baseRequest.setHandled(true)
+ // make sure we don't end up with // in the middle
+ val newUri = new URL(new URL(request.getRequestURL.toString), newPath).toURI
+ response.sendRedirect(newUri.toString)
}
}
+ val contextHandler = new ServletContextHandler()
+ val holder = new ServletHolder(servlet)
+ contextHandler.setContextPath(path)
+ contextHandler.addServlet(holder, "/")
+ contextHandler
}
/** Creates a handler for serving files from a static directory */
- def createStaticHandler(resourceBase: String): ResourceHandler = {
- val staticHandler = new ResourceHandler
+ def createStaticHandler(resourceBase: String, path: String): ServletContextHandler = {
+ val contextHandler = new ServletContextHandler()
+ val staticHandler = new DefaultServlet
+ val holder = new ServletHolder(staticHandler)
Option(getClass.getClassLoader.getResource(resourceBase)) match {
case Some(res) =>
- staticHandler.setResourceBase(res.toString)
+ holder.setInitParameter("resourceBase", res.toString)
+ holder.setInitParameter("welcomeServlets", "false")
+ holder.setInitParameter("pathInfoOnly", "false")
case None =>
throw new Exception("Could not find resource path for Web UI: " + resourceBase)
}
- staticHandler
+ contextHandler.setContextPath(path)
+ contextHandler.addServlet(holder, "/")
+ contextHandler
+ }
+
+ private def addFilters(handlers: Seq[ServletContextHandler], conf: SparkConf) {
+ val filters: Array[String] = conf.get("spark.ui.filters", "").split(',').map(_.trim())
+ filters.foreach {
+ case filter : String =>
+ if (!filter.isEmpty) {
+ logInfo("Adding filter: " + filter)
+ val holder : FilterHolder = new FilterHolder()
+ holder.setClassName(filter)
+ // get any parameters for each filter
+ val paramName = "spark." + filter + ".params"
+ val params = conf.get(paramName, "").split(',').map(_.trim()).toSet
+ params.foreach {
+ case param : String =>
+ if (!param.isEmpty) {
+ val parts = param.split("=")
+ if (parts.length == 2) holder.setInitParameter(parts(0), parts(1))
+ }
+ }
+ val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.ERROR,
+ DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.REQUEST)
+ handlers.foreach { case(handler) => handler.addFilter(holder, "/*", enumDispatcher) }
+ }
+ }
}
/**
@@ -98,17 +153,12 @@ private[spark] object JettyUtils extends Logging {
* If the desired port number is contended, continues incrementing ports until a free port is
* found. Returns the chosen port and the jetty Server object.
*/
- def startJettyServer(hostName: String, port: Int, handlers: Seq[(String, Handler)]): (Server, Int)
- = {
-
- val handlersToRegister = handlers.map { case(path, handler) =>
- val contextHandler = new ContextHandler(path)
- contextHandler.setHandler(handler)
- contextHandler.asInstanceOf[org.eclipse.jetty.server.Handler]
- }
+ def startJettyServer(hostName: String, port: Int, handlers: Seq[ServletContextHandler],
+ conf: SparkConf): (Server, Int) = {
+ addFilters(handlers, conf)
val handlerList = new HandlerList
- handlerList.setHandlers(handlersToRegister.toArray)
+ handlerList.setHandlers(handlers.toArray)
@tailrec
def connect(currentPort: Int): (Server, Int) = {
@@ -118,7 +168,9 @@ private[spark] object JettyUtils extends Logging {
server.setThreadPool(pool)
server.setHandler(handlerList)
- Try { server.start() } match {
+ Try {
+ server.start()
+ } match {
case s: Success[_] =>
(server, server.getConnectors.head.getLocalPort)
case f: Failure[_] =>
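// Editor's sketch: a minimal example of how the servlet-based JettyUtils API introduced
// above is meant to be wired together -- a Responder is converted implicitly to
// ServletParams, wrapped in a servlet that consults the SecurityManager, mounted under a
// context path, and served by startJettyServer. The object name, page content, and
// host/port below are hypothetical; the package clause is needed because JettyUtils and
// SecurityManager are private[spark].
package org.apache.spark

import javax.servlet.http.HttpServletRequest
import scala.xml.Node

import org.apache.spark.ui.JettyUtils._

object JettyUtilsSketch {
  def helloPage(request: HttpServletRequest): Seq[Node] = <p>Hello from the sketch UI</p>

  def main(args: Array[String]) {
    val conf = new SparkConf()
    val securityMgr = new SecurityManager(conf)
    // createServlet returns 401 unless checkUIViewPermissions accepts the remote user
    val handler = createServletHandler("/hello",
      createServlet((request: HttpServletRequest) => helloPage(request), securityMgr))
    // startJettyServer also attaches any filters listed in spark.ui.filters
    val (server, boundPort) = startJettyServer("0.0.0.0", 0, Seq(handler), conf)
    println("Sketch UI bound to port " + boundPort)
  }
}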
diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
index af6b65860e006..5f0dee64fedb7 100644
--- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
@@ -17,7 +17,10 @@
package org.apache.spark.ui
-import org.eclipse.jetty.server.{Handler, Server}
+import javax.servlet.http.HttpServletRequest
+
+import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.{Logging, SparkContext, SparkEnv}
import org.apache.spark.ui.JettyUtils._
@@ -34,9 +37,9 @@ private[spark] class SparkUI(sc: SparkContext) extends Logging {
var boundPort: Option[Int] = None
var server: Option[Server] = None
- val handlers = Seq[(String, Handler)](
- ("/static", createStaticHandler(SparkUI.STATIC_RESOURCE_DIR)),
- ("/", createRedirectHandler("/stages"))
+ val handlers = Seq[ServletContextHandler] (
+ createStaticHandler(SparkUI.STATIC_RESOURCE_DIR + "/static", "/static"),
+ createRedirectHandler("/stages", "/")
)
val storage = new BlockManagerUI(sc)
val jobs = new JobProgressUI(sc)
@@ -52,7 +55,7 @@ private[spark] class SparkUI(sc: SparkContext) extends Logging {
/** Bind the HTTP server which backs this web interface */
def bind() {
try {
- val (srv, usedPort) = JettyUtils.startJettyServer(host, port, allHandlers)
+ val (srv, usedPort) = JettyUtils.startJettyServer(host, port, allHandlers, sc.conf)
logInfo("Started Spark Web UI at http://%s:%d".format(host, usedPort))
server = Some(srv)
boundPort = Some(usedPort)
@@ -83,5 +86,5 @@ private[spark] class SparkUI(sc: SparkContext) extends Logging {
private[spark] object SparkUI {
val DEFAULT_PORT = "4040"
- val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static"
+ val STATIC_RESOURCE_DIR = "org/apache/spark/ui"
}
diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala
index 9e7cdc88162e8..14333476c0e31 100644
--- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConversions._
import scala.util.Properties
import scala.xml.Node
-import org.eclipse.jetty.server.Handler
+import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.SparkContext
import org.apache.spark.ui.JettyUtils._
@@ -32,8 +32,9 @@ import org.apache.spark.ui.UIUtils
private[spark] class EnvironmentUI(sc: SparkContext) {
- def getHandlers = Seq[(String, Handler)](
- ("/environment", (request: HttpServletRequest) => envDetails(request))
+ def getHandlers = Seq[ServletContextHandler](
+ createServletHandler("/environment",
+ createServlet((request: HttpServletRequest) => envDetails(request), sc.env.securityManager))
)
def envDetails(request: HttpServletRequest): Seq[Node] = {
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
index 1f3b7a4c231b6..4235cfeff9fa2 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
@@ -22,7 +22,7 @@ import javax.servlet.http.HttpServletRequest
import scala.collection.mutable.{HashMap, HashSet}
import scala.xml.Node
-import org.eclipse.jetty.server.Handler
+import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.{ExceptionFailure, Logging, SparkContext}
import org.apache.spark.executor.TaskMetrics
@@ -43,8 +43,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
sc.addSparkListener(listener)
}
- def getHandlers = Seq[(String, Handler)](
- ("/executors", (request: HttpServletRequest) => render(request))
+ def getHandlers = Seq[ServletContextHandler](
+ createServletHandler("/executors", createServlet((request: HttpServletRequest) => render
+ (request), sc.env.securityManager))
)
def render(request: HttpServletRequest): Seq[Node] = {
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
index 557bce6b66353..2d95d47e154cd 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
@@ -23,6 +23,7 @@ import javax.servlet.http.HttpServletRequest
import scala.Seq
import org.eclipse.jetty.server.Handler
+import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.SparkContext
import org.apache.spark.ui.JettyUtils._
@@ -45,9 +46,15 @@ private[spark] class JobProgressUI(val sc: SparkContext) {
def formatDuration(ms: Long) = Utils.msDurationToString(ms)
- def getHandlers = Seq[(String, Handler)](
- ("/stages/stage", (request: HttpServletRequest) => stagePage.render(request)),
- ("/stages/pool", (request: HttpServletRequest) => poolPage.render(request)),
- ("/stages", (request: HttpServletRequest) => indexPage.render(request))
+ def getHandlers = Seq[ServletContextHandler](
+ createServletHandler("/stages/stage",
+ createServlet((request: HttpServletRequest) => stagePage.render(request),
+ sc.env.securityManager)),
+ createServletHandler("/stages/pool",
+ createServlet((request: HttpServletRequest) => poolPage.render(request),
+ sc.env.securityManager)),
+ createServletHandler("/stages",
+ createServlet((request: HttpServletRequest) => indexPage.render(request),
+ sc.env.securityManager))
)
}
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
index dc18eab74e0da..cb2083eb019bf 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ui.storage
import javax.servlet.http.HttpServletRequest
-import org.eclipse.jetty.server.Handler
+import org.eclipse.jetty.servlet.ServletContextHandler
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.ui.JettyUtils._
@@ -29,8 +29,12 @@ private[spark] class BlockManagerUI(val sc: SparkContext) extends Logging {
val indexPage = new IndexPage(this)
val rddPage = new RDDPage(this)
- def getHandlers = Seq[(String, Handler)](
- ("/storage/rdd", (request: HttpServletRequest) => rddPage.render(request)),
- ("/storage", (request: HttpServletRequest) => indexPage.render(request))
+ def getHandlers = Seq[ServletContextHandler](
+ createServletHandler("/storage/rdd",
+ createServlet((request: HttpServletRequest) => rddPage.render(request),
+ sc.env.securityManager)),
+ createServletHandler("/storage",
+ createServlet((request: HttpServletRequest) => indexPage.render(request),
+ sc.env.securityManager))
)
}
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index f26ed47e58046..d0ff17db632c1 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -24,12 +24,12 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, IndestructibleActorSystem}
import com.typesafe.config.ConfigFactory
import org.apache.log4j.{Level, Logger}
-import org.apache.spark.SparkConf
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
/**
* Various utility classes for working with Akka.
*/
-private[spark] object AkkaUtils {
+private[spark] object AkkaUtils extends Logging {
/**
* Creates an ActorSystem ready for remoting, with various Spark features. Returns both the
@@ -42,14 +42,14 @@ private[spark] object AkkaUtils {
* of a fatal exception. This is used by [[org.apache.spark.executor.Executor]].
*/
def createActorSystem(name: String, host: String, port: Int, indestructible: Boolean = false,
- conf: SparkConf): (ActorSystem, Int) = {
+ conf: SparkConf, securityManager: SecurityManager): (ActorSystem, Int) = {
val akkaThreads = conf.getInt("spark.akka.threads", 4)
val akkaBatchSize = conf.getInt("spark.akka.batchSize", 15)
val akkaTimeout = conf.getInt("spark.akka.timeout", 100)
- val akkaFrameSize = conf.getInt("spark.akka.frameSize", 10)
+ val akkaFrameSize = maxFrameSizeBytes(conf)
val akkaLogLifecycleEvents = conf.getBoolean("spark.akka.logLifecycleEvents", false)
val lifecycleEvents = if (akkaLogLifecycleEvents) "on" else "off"
if (!akkaLogLifecycleEvents) {
@@ -65,6 +65,15 @@ private[spark] object AkkaUtils {
conf.getDouble("spark.akka.failure-detector.threshold", 300.0)
val akkaHeartBeatInterval = conf.getInt("spark.akka.heartbeat.interval", 1000)
+ val secretKey = securityManager.getSecretKey()
+ val isAuthOn = securityManager.isAuthenticationEnabled()
+ if (isAuthOn && secretKey == null) {
+ throw new Exception("Secret key is null with authentication on")
+ }
+ val requireCookie = if (isAuthOn) "on" else "off"
+ val secureCookie = if (isAuthOn) secretKey else ""
+ logDebug("In createActorSystem, requireCookie is: " + requireCookie)
+
val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap[String, String]).withFallback(
ConfigFactory.parseString(
s"""
@@ -72,6 +81,8 @@ private[spark] object AkkaUtils {
|akka.loggers = [""akka.event.slf4j.Slf4jLogger""]
|akka.stdout-loglevel = "ERROR"
|akka.jvm-exit-on-fatal-error = off
+ |akka.remote.require-cookie = "$requireCookie"
+ |akka.remote.secure-cookie = "$secureCookie"
|akka.remote.transport-failure-detector.heartbeat-interval = $akkaHeartBeatInterval s
|akka.remote.transport-failure-detector.acceptable-heartbeat-pause = $akkaHeartBeatPauses s
|akka.remote.transport-failure-detector.threshold = $akkaFailureDetector
@@ -81,7 +92,7 @@ private[spark] object AkkaUtils {
|akka.remote.netty.tcp.port = $port
|akka.remote.netty.tcp.tcp-nodelay = on
|akka.remote.netty.tcp.connection-timeout = $akkaTimeout s
- |akka.remote.netty.tcp.maximum-frame-size = ${akkaFrameSize}MiB
+ |akka.remote.netty.tcp.maximum-frame-size = ${akkaFrameSize}B
|akka.remote.netty.tcp.execution-pool-size = $akkaThreads
|akka.actor.default-dispatcher.throughput = $akkaBatchSize
|akka.log-config-on-start = $logAkkaConfig
@@ -110,4 +121,9 @@ private[spark] object AkkaUtils {
def lookupTimeout(conf: SparkConf): FiniteDuration = {
Duration.create(conf.get("spark.akka.lookupTimeout", "30").toLong, "seconds")
}
+
+ /** Returns the configured max frame size for Akka messages in bytes. */
+ def maxFrameSizeBytes(conf: SparkConf): Int = {
+ conf.getInt("spark.akka.frameSize", 10) * 1024 * 1024
+ }
}
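// Editor's sketch: how the new createActorSystem signature and the maxFrameSizeBytes
// helper above fit together. spark.akka.frameSize is still given in MB but is now
// converted to bytes before being written into akka.remote.netty.tcp.maximum-frame-size,
// and a SecurityManager must be supplied so the require-cookie/secure-cookie settings can
// be generated. The system name and host are hypothetical; the package clause is needed
// because AkkaUtils and SecurityManager are private[spark].
package org.apache.spark

import org.apache.spark.util.AkkaUtils

object AkkaUtilsSketch {
  def main(args: Array[String]) {
    val conf = new SparkConf().set("spark.akka.frameSize", "10")
    val securityMgr = new SecurityManager(conf)
    val frameSizeBytes = AkkaUtils.maxFrameSizeBytes(conf)   // 10 * 1024 * 1024
    val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sketch", "localhost", 0,
      conf = conf, securityManager = securityMgr)
    println("frame size = " + frameSizeBytes + " bytes, actor system bound to " + boundPort)
    actorSystem.shutdown()
  }
}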
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index 681d0a30cb3f8..a8d20ee332355 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -22,8 +22,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import scala.collection.mutable.Map
import scala.collection.mutable.Set
-import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
-import org.objectweb.asm.Opcodes._
+import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
+import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._
import org.apache.spark.Logging
diff --git a/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala b/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala
index bf71882ef770a..c539d2f708f95 100644
--- a/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala
+++ b/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala
@@ -23,9 +23,9 @@ import scala.util.control.{ControlThrowable, NonFatal}
import com.typesafe.config.Config
/**
- * An [[akka.actor.ActorSystem]] which refuses to shut down in the event of a fatal exception.
+ * An akka.actor.ActorSystem which refuses to shut down in the event of a fatal exception.
* This is necessary as Spark Executors are allowed to recover from fatal exceptions
- * (see [[org.apache.spark.executor.Executor]]).
+ * (see org.apache.spark.executor.Executor).
*/
object IndestructibleActorSystem {
def apply(name: String, config: Config): ActorSystem =
diff --git a/core/src/main/scala/org/apache/spark/util/MutablePair.scala b/core/src/main/scala/org/apache/spark/util/MutablePair.scala
index b053266f12748..2c1a6f8fd0a44 100644
--- a/core/src/main/scala/org/apache/spark/util/MutablePair.scala
+++ b/core/src/main/scala/org/apache/spark/util/MutablePair.scala
@@ -25,10 +25,20 @@ package org.apache.spark.util
* @param _2 Element 2 of this MutablePair
*/
case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T1,
- @specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T2]
+ @specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T2]
(var _1: T1, var _2: T2)
extends Product2[T1, T2]
{
+ /** No-arg constructor for serialization */
+ def this() = this(null.asInstanceOf[T1], null.asInstanceOf[T2])
+
+ /** Updates this pair with new values and returns itself */
+ def update(n1: T1, n2: T2): MutablePair[T1, T2] = {
+ _1 = n1
+ _2 = n2
+ this
+ }
+
override def toString = "(" + _1 + "," + _2 + ")"
override def canEqual(that: Any): Boolean = that.isInstanceOf[MutablePair[_,_]]
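// Editor's sketch: the no-arg constructor and update() added above let a single
// MutablePair be reused across records (e.g. inside a tight per-record loop) instead of
// allocating a new Tuple2 per element. The values below are arbitrary.
import org.apache.spark.util.MutablePair

object MutablePairSketch {
  def main(args: Array[String]) {
    val pair = new MutablePair[Int, String]()   // no-arg constructor, fields start empty
    Seq("a", "bb", "ccc").foreach { word =>
      // update() mutates the pair in place and returns it, so no per-record allocation
      println(pair.update(word.length, word))   // prints (1,a), (2,bb), (3,ccc)
    }
  }
}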
diff --git a/core/src/main/scala/org/apache/spark/util/StatCounter.scala b/core/src/main/scala/org/apache/spark/util/StatCounter.scala
index 5b0d2c36510b8..732748a7ff82b 100644
--- a/core/src/main/scala/org/apache/spark/util/StatCounter.scala
+++ b/core/src/main/scala/org/apache/spark/util/StatCounter.scala
@@ -19,9 +19,9 @@ package org.apache.spark.util
/**
* A class for tracking the statistics of a set of numbers (count, mean and variance) in a
- * numerically robust way. Includes support for merging two StatCounters. Based on
- * [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
- * Welford and Chan's algorithms for running variance]].
+ * numerically robust way. Includes support for merging two StatCounters. Based on Welford
+ * and Chan's [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance algorithms]]
+ * for running variance.
*
* @constructor Initialize the StatCounter with the given values.
*/
@@ -29,6 +29,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
private var n: Long = 0 // Running count of our values
private var mu: Double = 0 // Running mean of our values
private var m2: Double = 0 // Running variance numerator (sum of (x - mean)^2)
+ private var maxValue: Double = Double.NegativeInfinity // Running max of our values
+ private var minValue: Double = Double.PositiveInfinity // Running min of our values
merge(values)
@@ -41,6 +43,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
n += 1
mu += delta / n
m2 += delta * (value - mu)
+ maxValue = math.max(maxValue, value)
+ minValue = math.min(minValue, value)
this
}
@@ -58,7 +62,9 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
if (n == 0) {
mu = other.mu
m2 = other.m2
- n = other.n
+ n = other.n
+ maxValue = other.maxValue
+ minValue = other.minValue
} else if (other.n != 0) {
val delta = other.mu - mu
if (other.n * 10 < n) {
@@ -70,6 +76,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
}
m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
n += other.n
+ maxValue = math.max(maxValue, other.maxValue)
+ minValue = math.min(minValue, other.minValue)
}
this
}
@@ -81,6 +89,8 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
other.n = n
other.mu = mu
other.m2 = m2
+ other.maxValue = maxValue
+ other.minValue = minValue
other
}
@@ -90,6 +100,10 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
def sum: Double = n * mu
+ def max: Double = maxValue
+
+ def min: Double = minValue
+
/** Return the variance of the values. */
def variance: Double = {
if (n == 0) {
@@ -121,7 +135,7 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
def sampleStdev: Double = math.sqrt(sampleVariance)
override def toString: String = {
- "(count: %d, mean: %f, stdev: %f)".format(count, mean, stdev)
+ "(count: %d, mean: %f, stdev: %f, max: %f, min: %f)".format(count, mean, stdev, max, min)
}
}
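// Editor's sketch: with the running max/min added above, merging StatCounters from
// different partitions now carries the extremes along with count, mean and stdev.
// The sample values are arbitrary.
import org.apache.spark.util.StatCounter

object StatCounterSketch {
  def main(args: Array[String]) {
    val left = new StatCounter(Seq(1.0, 2.0, 3.0))
    val right = new StatCounter(Seq(10.0, -5.0))
    val merged = left.merge(right)   // merge() keeps the max/min seen on either side
    assert(merged.max == 10.0 && merged.min == -5.0)
    println(merged)   // (count: 5, mean: ..., stdev: ..., max: 10.000000, min: -5.000000)
  }
}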
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 8e69f1d3351b5..38a275d438959 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -18,7 +18,7 @@
package org.apache.spark.util
import java.io._
-import java.net.{InetAddress, Inet4Address, NetworkInterface, URI, URL}
+import java.net.{InetAddress, Inet4Address, NetworkInterface, URI, URL, URLConnection}
import java.nio.ByteBuffer
import java.util.{Locale, Random, UUID}
import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadPoolExecutor}
@@ -33,10 +33,11 @@ import com.google.common.io.Files
import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
-import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
import org.apache.spark.deploy.SparkHadoopUtil
+
/**
* Various utility methods used by Spark.
*/
@@ -232,6 +233,22 @@ private[spark] object Utils extends Logging {
}
}
+ /**
+ * Construct a URI containing the information used for authentication.
+ * This also sets the default authenticator to properly negotiate the
+ * user/password based on the URI.
+ *
+ * Note this relies on the Authenticator.setDefault being set properly to decode
+ * the user name and password. This is currently set in the SecurityManager.
+ */
+ def constructURIForAuthentication(uri: URI, securityMgr: SecurityManager): URI = {
+ val userCred = securityMgr.getSecretKey()
+ if (userCred == null) throw new Exception("Secret key is null with authentication on")
+ val userInfo = securityMgr.getHttpUser() + ":" + userCred
+ new URI(uri.getScheme(), userInfo, uri.getHost(), uri.getPort(), uri.getPath(),
+ uri.getQuery(), uri.getFragment())
+ }
+
/**
* Download a file requested by the executor. Supports fetching the file in a variety of ways,
* including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
@@ -239,7 +256,7 @@ private[spark] object Utils extends Logging {
* Throws SparkException if the target file already exists and has different contents than
* the requested file.
*/
- def fetchFile(url: String, targetDir: File, conf: SparkConf) {
+ def fetchFile(url: String, targetDir: File, conf: SparkConf, securityMgr: SecurityManager) {
val filename = url.split("/").last
val tempDir = getLocalDir(conf)
val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir))
@@ -249,7 +266,23 @@ private[spark] object Utils extends Logging {
uri.getScheme match {
case "http" | "https" | "ftp" =>
logInfo("Fetching " + url + " to " + tempFile)
- val in = new URL(url).openStream()
+
+ var uc: URLConnection = null
+ if (securityMgr.isAuthenticationEnabled()) {
+ logDebug("fetchFile with security enabled")
+ val newuri = constructURIForAuthentication(uri, securityMgr)
+ uc = newuri.toURL().openConnection()
+ uc.setAllowUserInteraction(false)
+ } else {
+ logDebug("fetchFile not using security")
+ uc = new URL(url).openConnection()
+ }
+
+ val timeout = conf.getInt("spark.files.fetchTimeout", 60) * 1000
+ uc.setConnectTimeout(timeout)
+ uc.setReadTimeout(timeout)
+ uc.connect()
+ val in = uc.getInputStream();
val out = new FileOutputStream(tempFile)
Utils.copyStream(in, out, true)
if (targetFile.exists && !Files.equal(tempFile, targetFile)) {
@@ -503,8 +536,6 @@ private[spark] object Utils extends Logging {
/**
* Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes.
- * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM
- * environment variable.
*/
def memoryStringToMb(str: String): Int = {
val lower = str.toLowerCase
@@ -688,8 +719,8 @@ private[spark] object Utils extends Logging {
new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass)
}
- def formatSparkCallSite = {
- val callSiteInfo = getCallSiteInfo
+ /** Returns a printable version of the call site info suitable for logs. */
+ def formatCallSiteInfo(callSiteInfo: CallSiteInfo = Utils.getCallSiteInfo) = {
"%s at %s:%s".format(callSiteInfo.lastSparkMethod, callSiteInfo.firstUserFile,
callSiteInfo.firstUserLine)
}
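// Editor's sketch: what constructURIForAuthentication above produces when fetchFile runs
// with spark.authenticate enabled -- the SecurityManager's http user and secret are
// embedded as URI user-info so the default java.net.Authenticator can answer the
// challenge. The URL and secret are hypothetical; the package clause is needed because
// Utils and SecurityManager are private[spark].
package org.apache.spark

import java.net.URI

import org.apache.spark.util.Utils

object FetchAuthSketch {
  def main(args: Array[String]) {
    val conf = new SparkConf()
      .set("spark.authenticate", "true")
      .set("spark.authenticate.secret", "good")
    val securityMgr = new SecurityManager(conf)
    val plain = new URI("http://example.com:8080/jars/app.jar")
    val withAuth = Utils.constructURIForAuthentication(plain, securityMgr)
    // prints the same URL with the http user and "good" inserted as user-info before the host
    println(withAuth)
  }
}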
diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala
index d437c055f33d4..dc4b8f253f259 100644
--- a/core/src/main/scala/org/apache/spark/util/Vector.scala
+++ b/core/src/main/scala/org/apache/spark/util/Vector.scala
@@ -136,7 +136,7 @@ object Vector {
/**
* Creates this [[org.apache.spark.util.Vector]] of given length containing random numbers
- * between 0.0 and 1.0. Optional [[scala.util.Random]] number generator can be provided.
+ * between 0.0 and 1.0. Optional scala.util.Random number generator can be provided.
*/
def random(length: Int, random: Random = new XORShiftRandom()) =
Vector(length, _ => random.nextDouble())
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index ed74a31f05bae..caa06d5b445b4 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -60,7 +60,7 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C,
- serializer: Serializer = SparkEnv.get.serializerManager.default,
+ serializer: Serializer = SparkEnv.get.serializer,
blockManager: BlockManager = SparkEnv.get.blockManager)
extends Iterable[(K, C)] with Serializable with Logging {
diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
index ca611b67ed91d..8a4cdea2fa7b1 100644
--- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
@@ -17,8 +17,11 @@
package org.apache.spark.util.random
+import java.nio.ByteBuffer
import java.util.{Random => JavaRandom}
+import scala.util.hashing.MurmurHash3
+
import org.apache.spark.util.Utils.timeIt
/**
@@ -36,8 +39,8 @@ private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) {
def this() = this(System.nanoTime)
- private var seed = init
-
+ private var seed = XORShiftRandom.hashSeed(init)
+
// we need to just override next - this will be called by nextInt, nextDouble,
// nextGaussian, nextLong, etc.
override protected def next(bits: Int): Int = {
@@ -49,13 +52,19 @@ private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) {
}
override def setSeed(s: Long) {
- seed = s
+ seed = XORShiftRandom.hashSeed(s)
}
}
/** Contains benchmark method and main method to run benchmark of the RNG */
private[spark] object XORShiftRandom {
+ /** Hash seeds to have 0/1 bits throughout. */
+ private def hashSeed(seed: Long): Long = {
+ val bytes = ByteBuffer.allocate(java.lang.Long.SIZE).putLong(seed).array()
+ MurmurHash3.bytesHash(bytes)
+ }
+
/**
* Main method for running benchmark
* @param args takes one argument - the number of random numbers to generate
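// Editor's sketch: the hashSeed change above means two generators created with nearby
// seeds (e.g. consecutive System.nanoTime values) no longer start from correlated state,
// because the raw seed is scrambled with MurmurHash3 before use. Seeds below are
// arbitrary; the package clause is needed because XORShiftRandom is private[spark].
package org.apache.spark.util.random

object XORShiftRandomSketch {
  def main(args: Array[String]) {
    val a = new XORShiftRandom(42L)
    val b = new XORShiftRandom(43L)
    // With hashed seeds the first few draws from adjacent seeds already diverge widely
    println(Seq.fill(3)(a.nextInt()).mkString(", "))
    println(Seq.fill(3)(b.nextInt()).mkString(", "))
  }
}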
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 20232e9fbb8d0..40e853c39ca99 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -75,8 +75,9 @@ public int compare(Integer a, Integer b) {
else if (a < b) return 1;
else return 0;
}
- };
+ }
+ @SuppressWarnings("unchecked")
@Test
public void sparkContextUnion() {
// Union of non-specialized JavaRDDs
@@ -109,6 +110,37 @@ public void sparkContextUnion() {
Assert.assertEquals(4, pUnion.count());
}
+ @SuppressWarnings("unchecked")
+ @Test
+ public void intersection() {
+ List<Integer> ints1 = Arrays.asList(1, 10, 2, 3, 4, 5);
+ List<Integer> ints2 = Arrays.asList(1, 6, 2, 3, 7, 8);
+ JavaRDD<Integer> s1 = sc.parallelize(ints1);
+ JavaRDD<Integer> s2 = sc.parallelize(ints2);
+
+ JavaRDD<Integer> intersections = s1.intersection(s2);
+ Assert.assertEquals(3, intersections.count());
+
+ ArrayList<Integer> list = new ArrayList<Integer>();
+ JavaRDD<Integer> empty = sc.parallelize(list);
+ JavaRDD<Integer> emptyIntersection = empty.intersection(s2);
+ Assert.assertEquals(0, emptyIntersection.count());
+
+ List<Double> doubles = Arrays.asList(1.0, 2.0);
+ JavaDoubleRDD d1 = sc.parallelizeDoubles(doubles);
+ JavaDoubleRDD d2 = sc.parallelizeDoubles(doubles);
+ JavaDoubleRDD dIntersection = d1.intersection(d2);
+ Assert.assertEquals(2, dIntersection.count());
+
+ List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>();
+ pairs.add(new Tuple2<Integer, Integer>(1, 2));
+ pairs.add(new Tuple2<Integer, Integer>(3, 4));
+ JavaPairRDD<Integer, Integer> p1 = sc.parallelizePairs(pairs);
+ JavaPairRDD<Integer, Integer> p2 = sc.parallelizePairs(pairs);
+ JavaPairRDD<Integer, Integer> pIntersection = p1.intersection(p2);
+ Assert.assertEquals(2, pIntersection.count());
+ }
+
@Test
public void sortByKey() {
List> pairs = new ArrayList>();
@@ -148,6 +180,7 @@ public void call(String s) {
Assert.assertEquals(2, foreachCalls);
}
+ @SuppressWarnings("unchecked")
@Test
public void lookup() {
JavaPairRDD categories = sc.parallelizePairs(Arrays.asList(
@@ -179,6 +212,7 @@ public Boolean call(Integer x) {
Assert.assertEquals(5, oddsAndEvens.lookup(false).get(0).size()); // Odds
}
+ @SuppressWarnings("unchecked")
@Test
public void cogroup() {
JavaPairRDD categories = sc.parallelizePairs(Arrays.asList(
@@ -197,6 +231,7 @@ public void cogroup() {
cogrouped.collect();
}
+ @SuppressWarnings("unchecked")
@Test
public void leftOuterJoin() {
JavaPairRDD rdd1 = sc.parallelizePairs(Arrays.asList(
@@ -243,6 +278,7 @@ public Integer call(Integer a, Integer b) {
Assert.assertEquals(33, sum);
}
+ @SuppressWarnings("unchecked")
@Test
public void foldByKey() {
List> pairs = Arrays.asList(
@@ -265,6 +301,7 @@ public Integer call(Integer a, Integer b) {
Assert.assertEquals(3, sums.lookup(3).get(0).intValue());
}
+ @SuppressWarnings("unchecked")
@Test
public void reduceByKey() {
List> pairs = Arrays.asList(
@@ -320,8 +357,8 @@ public void approximateResults() {
public void take() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13));
Assert.assertEquals(1, rdd.first().intValue());
- List<Integer> firstTwo = rdd.take(2);
- List<Integer> sample = rdd.takeSample(false, 2, 42);
+ rdd.take(2);
+ rdd.takeSample(false, 2, 42);
}
@Test
@@ -359,8 +396,8 @@ public Boolean call(Double x) {
Assert.assertEquals(2.49444, rdd.stdev(), 0.01);
Assert.assertEquals(2.73252, rdd.sampleStdev(), 0.01);
- Double first = rdd.first();
- List<Double> take = rdd.take(5);
+ rdd.first();
+ rdd.take(5);
}
@Test
@@ -380,14 +417,14 @@ public void javaDoubleRDDHistoGram() {
@Test
public void map() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
- JavaDoubleRDD doubles = rdd.map(new DoubleFunction<Integer>() {
+ JavaDoubleRDD doubles = rdd.mapToDouble(new DoubleFunction<Integer>() {
@Override
- public Double call(Integer x) {
+ public double call(Integer x) {
return 1.0 * x;
}
}).cache();
doubles.collect();
- JavaPairRDD<Integer, Integer> pairs = rdd.map(new PairFunction<Integer, Integer, Integer>() {
+ JavaPairRDD<Integer, Integer> pairs = rdd.mapToPair(new PairFunction<Integer, Integer, Integer>() {
@Override
public Tuple2<Integer, Integer> call(Integer x) {
return new Tuple2<Integer, Integer>(x, x);
@@ -416,7 +453,7 @@ public Iterable call(String x) {
Assert.assertEquals("Hello", words.first());
Assert.assertEquals(11, words.count());
- JavaPairRDD<String, String> pairs = rdd.flatMap(
+ JavaPairRDD<String, String> pairs = rdd.flatMapToPair(
new PairFlatMapFunction<String, String, String>() {
@Override
@@ -430,7 +467,7 @@ public Iterable> call(String s) {
Assert.assertEquals(new Tuple2("Hello", "Hello"), pairs.first());
Assert.assertEquals(11, pairs.count());
- JavaDoubleRDD doubles = rdd.flatMap(new DoubleFlatMapFunction<String>() {
+ JavaDoubleRDD doubles = rdd.flatMapToDouble(new DoubleFlatMapFunction<String>() {
@Override
public Iterable<Double> call(String s) {
List<Double> lengths = new LinkedList<Double>();
@@ -438,11 +475,11 @@ public Iterable call(String s) {
return lengths;
}
});
- Double x = doubles.first();
- Assert.assertEquals(5.0, doubles.first().doubleValue(), 0.01);
+ Assert.assertEquals(5.0, doubles.first(), 0.01);
Assert.assertEquals(11, pairs.count());
}
+ @SuppressWarnings("unchecked")
@Test
public void mapsFromPairsToPairs() {
List<Tuple2<Integer, String>> pairs = Arrays.asList(
@@ -453,7 +490,7 @@ public void mapsFromPairsToPairs() {
JavaPairRDD<Integer, String> pairRDD = sc.parallelizePairs(pairs);
// Regression test for SPARK-668:
- JavaPairRDD<String, Integer> swapped = pairRDD.flatMap(
+ JavaPairRDD<String, Integer> swapped = pairRDD.flatMapToPair(
new PairFlatMapFunction<Tuple2<Integer, String>, String, Integer>() {
@Override
public Iterable<Tuple2<String, Integer>> call(Tuple2<Integer, String> item) throws Exception {
@@ -463,7 +500,7 @@ public Iterable> call(Tuple2 item) thro
swapped.collect();
// There was never a bug here, but it's worth testing:
- pairRDD.map(new PairFunction<Tuple2<Integer, String>, String, Integer>() {
+ pairRDD.mapToPair(new PairFunction<Tuple2<Integer, String>, String, Integer>() {
@Override
public Tuple2<String, Integer> call(Tuple2<Integer, String> item) throws Exception {
return item.swap();
@@ -509,6 +546,7 @@ public void repartition() {
}
}
+ @SuppressWarnings("unchecked")
@Test
public void persist() {
JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
@@ -573,6 +611,7 @@ public void textFilesCompressed() throws IOException {
Assert.assertEquals(expected, readRDD.collect());
}
+ @SuppressWarnings("unchecked")
@Test
public void sequenceFile() {
File tempDir = Files.createTempDir();
@@ -584,7 +623,7 @@ public void sequenceFile() {
);
JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs);
- rdd.map(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
+ rdd.mapToPair(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
@Override
public Tuple2<IntWritable, Text> call(Tuple2<Integer, String> pair) {
return new Tuple2<IntWritable, Text>(new IntWritable(pair._1()), new Text(pair._2()));
@@ -593,7 +632,7 @@ public Tuple2 call(Tuple2 pair) {
// Try reading the output back as an object file
JavaPairRDD<Integer, String> readRDD = sc.sequenceFile(outputDir, IntWritable.class,
- Text.class).map(new PairFunction<Tuple2<IntWritable, Text>, Integer, String>() {
+ Text.class).mapToPair(new PairFunction<Tuple2<IntWritable, Text>, Integer, String>() {
@Override
public Tuple2<Integer, String> call(Tuple2<IntWritable, Text> pair) {
return new Tuple2<Integer, String>(pair._1().get(), pair._2().toString());
@@ -602,6 +641,7 @@ public Tuple2 call(Tuple2 pair) {
Assert.assertEquals(pairs, readRDD.collect());
}
+ @SuppressWarnings("unchecked")
@Test
public void writeWithNewAPIHadoopFile() {
File tempDir = Files.createTempDir();
@@ -613,7 +653,7 @@ public void writeWithNewAPIHadoopFile() {
);
JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs);
- rdd.map(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
+ rdd.mapToPair(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
@Override
public Tuple2<IntWritable, Text> call(Tuple2<Integer, String> pair) {
return new Tuple2<IntWritable, Text>(new IntWritable(pair._1()), new Text(pair._2()));
@@ -632,6 +672,7 @@ public String call(Tuple2 x) {
}).collect().toString());
}
+ @SuppressWarnings("unchecked")
@Test
public void readWithNewAPIHadoopFile() throws IOException {
File tempDir = Files.createTempDir();
@@ -643,7 +684,7 @@ public void readWithNewAPIHadoopFile() throws IOException {
);
JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs);
- rdd.map(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
+ rdd.mapToPair(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
@Override
public Tuple2<IntWritable, Text> call(Tuple2<Integer, String> pair) {
return new Tuple2<IntWritable, Text>(new IntWritable(pair._1()), new Text(pair._2()));
@@ -674,6 +715,7 @@ public void objectFilesOfInts() {
Assert.assertEquals(expected, readRDD.collect());
}
+ @SuppressWarnings("unchecked")
@Test
public void objectFilesOfComplexTypes() {
File tempDir = Files.createTempDir();
@@ -690,6 +732,7 @@ public void objectFilesOfComplexTypes() {
Assert.assertEquals(pairs, readRDD.collect());
}
+ @SuppressWarnings("unchecked")
@Test
public void hadoopFile() {
File tempDir = Files.createTempDir();
@@ -701,7 +744,7 @@ public void hadoopFile() {
);
JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs);
- rdd.map(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
+ rdd.mapToPair(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
@Override
public Tuple2<IntWritable, Text> call(Tuple2<Integer, String> pair) {
return new Tuple2<IntWritable, Text>(new IntWritable(pair._1()), new Text(pair._2()));
@@ -719,6 +762,7 @@ public String call(Tuple2 x) {
}).collect().toString());
}
+ @SuppressWarnings("unchecked")
@Test
public void hadoopFileCompressed() {
File tempDir = Files.createTempDir();
@@ -730,7 +774,7 @@ public void hadoopFileCompressed() {
);
JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs);
- rdd.map(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
+ rdd.mapToPair(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
@Override
public Tuple2<IntWritable, Text> call(Tuple2<Integer, String> pair) {
return new Tuple2<IntWritable, Text>(new IntWritable(pair._1()), new Text(pair._2()));
@@ -753,9 +797,9 @@ public String call(Tuple2 x) {
@Test
public void zip() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
- JavaDoubleRDD doubles = rdd.map(new DoubleFunction<Integer>() {
+ JavaDoubleRDD doubles = rdd.mapToDouble(new DoubleFunction<Integer>() {
@Override
- public Double call(Integer x) {
+ public double call(Integer x) {
return 1.0 * x;
}
});
@@ -824,7 +868,7 @@ public Float zero(Float initialValue) {
}
};
- final Accumulator<Float> floatAccum = sc.accumulator((Float) 10.0f, floatAccumulatorParam);
+ final Accumulator<Float> floatAccum = sc.accumulator(10.0f, floatAccumulatorParam);
rdd.foreach(new VoidFunction<Integer>() {
public void call(Integer x) {
floatAccum.add((float) x);
@@ -876,16 +920,17 @@ public void checkpointAndRestore() {
Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect());
}
+ @SuppressWarnings("unchecked")
@Test
public void mapOnPairRDD() {
JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1,2,3,4));
- JavaPairRDD<Integer, Integer> rdd2 = rdd1.map(new PairFunction<Integer, Integer, Integer>() {
+ JavaPairRDD<Integer, Integer> rdd2 = rdd1.mapToPair(new PairFunction<Integer, Integer, Integer>() {
@Override
public Tuple2<Integer, Integer> call(Integer i) throws Exception {
return new Tuple2<Integer, Integer>(i, i % 2);
}
});
- JavaPairRDD<Integer, Integer> rdd3 = rdd2.map(
+ JavaPairRDD<Integer, Integer> rdd3 = rdd2.mapToPair(
new PairFunction<Tuple2<Integer, Integer>, Integer, Integer>() {
@Override
public Tuple2<Integer, Integer> call(Tuple2<Integer, Integer> in) throws Exception {
@@ -900,11 +945,12 @@ public Tuple2 call(Tuple2 in) throws Excepti
}
+ @SuppressWarnings("unchecked")
@Test
public void collectPartitions() {
JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3);
- JavaPairRDD<Integer, Integer> rdd2 = rdd1.map(new PairFunction<Integer, Integer, Integer>() {
+ JavaPairRDD<Integer, Integer> rdd2 = rdd1.mapToPair(new PairFunction<Integer, Integer, Integer>() {
@Override
public Tuple2<Integer, Integer> call(Integer i) throws Exception {
return new Tuple2<Integer, Integer>(i, i % 2);
@@ -968,14 +1014,14 @@ public void countApproxDistinctByKey() {
@Test
public void collectAsMapWithIntArrayValues() {
// Regression test for SPARK-1040
- JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(new Integer[] { 1 }));
- JavaPairRDD<Integer, int[]> pairRDD = rdd.map(new PairFunction<Integer, Integer, int[]>() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1));
+ JavaPairRDD<Integer, int[]> pairRDD = rdd.mapToPair(new PairFunction<Integer, Integer, int[]>() {
@Override
public Tuple2<Integer, int[]> call(Integer x) throws Exception {
return new Tuple2<Integer, int[]>(x, new int[] { x });
}
});
pairRDD.collect(); // Works fine
- Map<Integer, int[]> map = pairRDD.collectAsMap(); // Used to crash with ClassCastException
+ pairRDD.collectAsMap(); // Used to crash with ClassCastException
}
}
diff --git a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
new file mode 100644
index 0000000000000..d2e303d81c4c8
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.FunSuite
+
+import akka.actor._
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.AkkaUtils
+import scala.concurrent.Await
+
+/**
+ * Test the AkkaUtils with various security settings.
+ */
+class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
+
+ test("remote fetch security bad password") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "true")
+ conf.set("spark.authenticate.secret", "good")
+
+ val securityManager = new SecurityManager(conf);
+ val hostname = "localhost"
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
+ conf = conf, securityManager = securityManager)
+ System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
+ System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+ assert(securityManager.isAuthenticationEnabled() === true)
+
+ val masterTracker = new MapOutputTrackerMaster(conf)
+ masterTracker.trackerActor = actorSystem.actorOf(
+ Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+
+ val badconf = new SparkConf
+ badconf.set("spark.authenticate", "true")
+ badconf.set("spark.authenticate.secret", "bad")
+ val securityManagerBad = new SecurityManager(badconf)
+
+ assert(securityManagerBad.isAuthenticationEnabled() === true)
+
+ val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
+ conf = conf, securityManager = securityManagerBad)
+ val slaveTracker = new MapOutputTracker(conf)
+ val selection = slaveSystem.actorSelection(
+ s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
+ val timeout = AkkaUtils.lookupTimeout(conf)
+ intercept[akka.actor.ActorNotFound] {
+ slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+ }
+
+ actorSystem.shutdown()
+ slaveSystem.shutdown()
+ }
+
+ test("remote fetch security off") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "false")
+ conf.set("spark.authenticate.secret", "bad")
+ val securityManager = new SecurityManager(conf);
+
+ val hostname = "localhost"
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
+ conf = conf, securityManager = securityManager)
+ System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
+ System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+
+ assert(securityManager.isAuthenticationEnabled() === false)
+
+ val masterTracker = new MapOutputTrackerMaster(conf)
+ masterTracker.trackerActor = actorSystem.actorOf(
+ Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+
+ val badconf = new SparkConf
+ badconf.set("spark.authenticate", "false")
+ badconf.set("spark.authenticate.secret", "good")
+ val securityManagerBad = new SecurityManager(badconf);
+
+ val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
+ conf = badconf, securityManager = securityManagerBad)
+ val slaveTracker = new MapOutputTracker(conf)
+ val selection = slaveSystem.actorSelection(
+ s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
+ val timeout = AkkaUtils.lookupTimeout(conf)
+ slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+
+ assert(securityManagerBad.isAuthenticationEnabled() === false)
+
+ masterTracker.registerShuffle(10, 1)
+ masterTracker.incrementEpoch()
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+
+ val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+ val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+ masterTracker.registerMapOutput(10, 0, new MapStatus(
+ BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000)))
+ masterTracker.incrementEpoch()
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+
+ // this should succeed since security off
+ assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
+ Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
+
+ actorSystem.shutdown()
+ slaveSystem.shutdown()
+ }
+
+ test("remote fetch security pass") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "true")
+ conf.set("spark.authenticate.secret", "good")
+ val securityManager = new SecurityManager(conf);
+
+ val hostname = "localhost"
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
+ conf = conf, securityManager = securityManager)
+ System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
+ System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+
+ assert(securityManager.isAuthenticationEnabled() === true)
+
+ val masterTracker = new MapOutputTrackerMaster(conf)
+ masterTracker.trackerActor = actorSystem.actorOf(
+ Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+
+ val goodconf = new SparkConf
+ goodconf.set("spark.authenticate", "true")
+ goodconf.set("spark.authenticate.secret", "good")
+ val securityManagerGood = new SecurityManager(goodconf);
+
+ assert(securityManagerGood.isAuthenticationEnabled() === true)
+
+ val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
+ conf = goodconf, securityManager = securityManagerGood)
+ val slaveTracker = new MapOutputTracker(conf)
+ val selection = slaveSystem.actorSelection(
+ s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
+ val timeout = AkkaUtils.lookupTimeout(conf)
+ slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+
+ masterTracker.registerShuffle(10, 1)
+ masterTracker.incrementEpoch()
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+
+ val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+ val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+ masterTracker.registerMapOutput(10, 0, new MapStatus(
+ BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000)))
+ masterTracker.incrementEpoch()
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+
+ // this should succeed since security on and passwords match
+ assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
+ Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
+
+ actorSystem.shutdown()
+ slaveSystem.shutdown()
+ }
+
+ test("remote fetch security off client") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "true")
+ conf.set("spark.authenticate.secret", "good")
+
+ val securityManager = new SecurityManager(conf);
+
+ val hostname = "localhost"
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
+ conf = conf, securityManager = securityManager)
+ System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
+ System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+
+ assert(securityManager.isAuthenticationEnabled() === true)
+
+ val masterTracker = new MapOutputTrackerMaster(conf)
+ masterTracker.trackerActor = actorSystem.actorOf(
+ Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+
+ val badconf = new SparkConf
+ badconf.set("spark.authenticate", "false")
+ badconf.set("spark.authenticate.secret", "bad")
+ val securityManagerBad = new SecurityManager(badconf);
+
+ assert(securityManagerBad.isAuthenticationEnabled() === false)
+
+ val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
+ conf = badconf, securityManager = securityManagerBad)
+ val slaveTracker = new MapOutputTracker(conf)
+ val selection = slaveSystem.actorSelection(
+ s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
+ val timeout = AkkaUtils.lookupTimeout(conf)
+ intercept[akka.actor.ActorNotFound] {
+ slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+ }
+
+ actorSystem.shutdown()
+ slaveSystem.shutdown()
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
index e022accee6d08..96ba3929c1685 100644
--- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
@@ -21,6 +21,7 @@ import org.scalatest.FunSuite
class BroadcastSuite extends FunSuite with LocalSparkContext {
+
override def afterEach() {
super.afterEach()
System.clearProperty("spark.broadcast.factory")
diff --git a/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala
new file mode 100644
index 0000000000000..80f7ec00c74b2
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.FunSuite
+
+import java.nio._
+
+import org.apache.spark.network.{ConnectionManager, Message, ConnectionManagerId}
+import scala.concurrent.Await
+import scala.concurrent.TimeoutException
+import scala.concurrent.duration._
+
+
+/**
+ * Test the ConnectionManager with various security settings.
+ */
+class ConnectionManagerSuite extends FunSuite {
+
+ test("security default off") {
+ val conf = new SparkConf
+ val securityManager = new SecurityManager(conf)
+ val manager = new ConnectionManager(0, conf, securityManager)
+ var receivedMessage = false
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ receivedMessage = true
+ None
+ })
+
+ val size = 10 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliablySync(manager.id, bufferMessage)
+
+ assert(receivedMessage == true)
+
+ manager.stop()
+ }
+
+ test("security on same password") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "true")
+ conf.set("spark.authenticate.secret", "good")
+ val securityManager = new SecurityManager(conf)
+ val manager = new ConnectionManager(0, conf, securityManager)
+ var numReceivedMessages = 0
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedMessages += 1
+ None
+ })
+ val managerServer = new ConnectionManager(0, conf, securityManager)
+ var numReceivedServerMessages = 0
+ managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedServerMessages += 1
+ None
+ })
+
+ val size = 10 * 1024 * 1024
+ val count = 10
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ (0 until count).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliablySync(managerServer.id, bufferMessage)
+ })
+
+ assert(numReceivedServerMessages == 10)
+ assert(numReceivedMessages == 0)
+
+ manager.stop()
+ managerServer.stop()
+ }
+
+ test("security mismatch password") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "true")
+ conf.set("spark.authenticate.secret", "good")
+ val securityManager = new SecurityManager(conf)
+ val manager = new ConnectionManager(0, conf, securityManager)
+ var numReceivedMessages = 0
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedMessages += 1
+ None
+ })
+
+ val badconf = new SparkConf
+ badconf.set("spark.authenticate", "true")
+ badconf.set("spark.authenticate.secret", "bad")
+ val badsecurityManager = new SecurityManager(badconf)
+ val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
+ var numReceivedServerMessages = 0
+
+ managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedServerMessages += 1
+ None
+ })
+
+ val size = 10 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliablySync(managerServer.id, bufferMessage)
+
+ assert(numReceivedServerMessages == 0)
+ assert(numReceivedMessages == 0)
+
+ manager.stop()
+ managerServer.stop()
+ }
+
+ test("security mismatch auth off") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "false")
+ conf.set("spark.authenticate.secret", "good")
+ val securityManager = new SecurityManager(conf)
+ val manager = new ConnectionManager(0, conf, securityManager)
+ var numReceivedMessages = 0
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedMessages += 1
+ None
+ })
+
+ val badconf = new SparkConf
+ badconf.set("spark.authenticate", "true")
+ badconf.set("spark.authenticate.secret", "good")
+ val badsecurityManager = new SecurityManager(badconf)
+ val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
+ var numReceivedServerMessages = 0
+ managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedServerMessages += 1
+ None
+ })
+
+ val size = 10 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ (0 until 1).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliably(managerServer.id, bufferMessage)
+ }).foreach(f => {
+ try {
+ val g = Await.result(f, 1 second)
+ assert(false)
+ } catch {
+ case e: TimeoutException => {
+ // we should timeout here since the client can't do the negotiation
+ assert(true)
+ }
+ }
+ })
+
+ assert(numReceivedServerMessages == 0)
+ assert(numReceivedMessages == 0)
+ manager.stop()
+ managerServer.stop()
+ }
+
+ test("security auth off") {
+ val conf = new SparkConf
+ conf.set("spark.authenticate", "false")
+ val securityManager = new SecurityManager(conf)
+ val manager = new ConnectionManager(0, conf, securityManager)
+ var numReceivedMessages = 0
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedMessages += 1
+ None
+ })
+
+ val badconf = new SparkConf
+ badconf.set("spark.authenticate", "false")
+ val badsecurityManager = new SecurityManager(badconf)
+ val managerServer = new ConnectionManager(0, badconf, badsecurityManager)
+ var numReceivedServerMessages = 0
+
+ managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ numReceivedServerMessages += 1
+ None
+ })
+
+ val size = 10 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ (0 until 10).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliably(managerServer.id, bufferMessage)
+ }).foreach(f => {
+ try {
+ val g = Await.result(f, 1 second)
+ if (!g.isDefined) assert(false) else assert(true)
+ } catch {
+ case e: Exception => {
+ assert(false)
+ }
+ }
+ })
+ assert(numReceivedServerMessages == 10)
+ assert(numReceivedMessages == 0)
+
+ manager.stop()
+ managerServer.stop()
+ }
+
+
+
+}
+
diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala
index e0e8011278649..9cbdfc54a3dc8 100644
--- a/core/src/test/scala/org/apache/spark/DriverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala
@@ -30,6 +30,7 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark.util.Utils
class DriverSuite extends FunSuite with Timeouts {
+
test("driver should exit after finishing") {
val sparkHome = sys.env.get("SPARK_HOME").orElse(sys.props.get("spark.home")).get
// Regression test for SPARK-530: "Spark driver process doesn't exit after finishing"
diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
index 9be67b3c95abd..aee9ab9091dac 100644
--- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
@@ -30,6 +30,12 @@ class FileServerSuite extends FunSuite with LocalSparkContext {
@transient var tmpFile: File = _
@transient var tmpJarUrl: String = _
+ override def beforeEach() {
+ super.beforeEach()
+ resetSparkContext()
+ System.setProperty("spark.authenticate", "false")
+ }
+
override def beforeAll() {
super.beforeAll()
val tmpDir = new File(Files.createTempDir(), "test")
@@ -43,6 +49,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext {
val jarFile = new File(tmpDir, "test.jar")
val jarStream = new FileOutputStream(jarFile)
val jar = new JarOutputStream(jarStream, new java.util.jar.Manifest())
+ System.setProperty("spark.authenticate", "false")
val jarEntry = new JarEntry(textFile.getName)
jar.putNextEntry(jarEntry)
@@ -77,6 +84,25 @@ class FileServerSuite extends FunSuite with LocalSparkContext {
assert(result.toSet === Set((1,200), (2,300), (3,500)))
}
+ test("Distributing files locally security On") {
+ val sparkConf = new SparkConf(false)
+ sparkConf.set("spark.authenticate", "true")
+ sparkConf.set("spark.authenticate.secret", "good")
+ sc = new SparkContext("local[4]", "test", sparkConf)
+
+ sc.addFile(tmpFile.toString)
+ assert(sc.env.securityManager.isAuthenticationEnabled() === true)
+ val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
+ val result = sc.parallelize(testData).reduceByKey {
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
+ val fileVal = in.readLine().toInt
+ in.close()
+ _ * fileVal + _ * fileVal
+ }.collect()
+ assert(result.toSet === Set((1,200), (2,300), (3,500)))
+ }
+
test("Distributing files locally using URL as input") {
// addFile("file:///....")
sc = new SparkContext("local[4]", "test")
diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala
index 8ff02aef67aa0..01af94077144a 100644
--- a/core/src/test/scala/org/apache/spark/FileSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileSuite.scala
@@ -24,6 +24,9 @@ import scala.io.Source
import com.google.common.io.Files
import org.apache.hadoop.io._
import org.apache.hadoop.io.compress.DefaultCodec
+import org.apache.hadoop.mapred.{JobConf, FileAlreadyExistsException, TextOutputFormat}
+import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat}
+import org.apache.hadoop.mapreduce.Job
import org.scalatest.FunSuite
import org.apache.spark.SparkContext._
@@ -208,4 +211,70 @@ class FileSuite extends FunSuite with LocalSparkContext {
assert(rdd.count() === 3)
assert(rdd.count() === 3)
}
+
+ test ("prevent user from overwriting the empty directory (old Hadoop API)") {
+ sc = new SparkContext("local", "test")
+ val tempdir = Files.createTempDir()
+ val randomRDD = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 1)
+ intercept[FileAlreadyExistsException] {
+ randomRDD.saveAsTextFile(tempdir.getPath)
+ }
+ }
+
+ test ("prevent user from overwriting the non-empty directory (old Hadoop API)") {
+ sc = new SparkContext("local", "test")
+ val tempdir = Files.createTempDir()
+ val randomRDD = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 1)
+ randomRDD.saveAsTextFile(tempdir.getPath + "/output")
+ assert(new File(tempdir.getPath + "/output/part-00000").exists() === true)
+ intercept[FileAlreadyExistsException] {
+ randomRDD.saveAsTextFile(tempdir.getPath + "/output")
+ }
+ }
+
+ test ("prevent user from overwriting the empty directory (new Hadoop API)") {
+ sc = new SparkContext("local", "test")
+ val tempdir = Files.createTempDir()
+ val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1)
+ intercept[FileAlreadyExistsException] {
+ randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempdir.getPath)
+ }
+ }
+
+ test ("prevent user from overwriting the non-empty directory (new Hadoop API)") {
+ sc = new SparkContext("local", "test")
+ val tempdir = Files.createTempDir()
+ val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1)
+ randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempdir.getPath + "/output")
+ assert(new File(tempdir.getPath + "/output/part-r-00000").exists() === true)
+ intercept[FileAlreadyExistsException] {
+ randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempdir.getPath)
+ }
+ }
+
+ test ("save Hadoop Dataset through old Hadoop API") {
+ sc = new SparkContext("local", "test")
+ val tempdir = Files.createTempDir()
+ val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1)
+ val job = new JobConf()
+ job.setOutputKeyClass(classOf[String])
+ job.setOutputValueClass(classOf[String])
+ job.set("mapred.output.format.class", classOf[TextOutputFormat[String, String]].getName)
+ job.set("mapred.output.dir", tempdir.getPath + "/outputDataset_old")
+ randomRDD.saveAsHadoopDataset(job)
+ assert(new File(tempdir.getPath + "/outputDataset_old/part-00000").exists() === true)
+ }
+
+ test ("save Hadoop Dataset through new Hadoop API") {
+ sc = new SparkContext("local", "test")
+ val tempdir = Files.createTempDir()
+ val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1)
+ val job = new Job(sc.hadoopConfiguration)
+ job.setOutputKeyClass(classOf[String])
+ job.setOutputValueClass(classOf[String])
+ job.setOutputFormatClass(classOf[NewTextOutputFormat[String, String]])
+ job.getConfiguration.set("mapred.output.dir", tempdir.getPath + "/outputDataset_new")
+ randomRDD.saveAsNewAPIHadoopDataset(job.getConfiguration)
+ assert(new File(tempdir.getPath + "/outputDataset_new/part-r-00000").exists() === true)
+ }
}
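
Editor's note: these tests pin down the new contract — saving into a directory that already exists now fails with Hadoop's FileAlreadyExistsException instead of silently overwriting it. A caller-side sketch under that assumption; the path handling and error message are illustrative only:

    import org.apache.hadoop.mapred.FileAlreadyExistsException
    import org.apache.spark.SparkContext

    object SafeSaveSketch {
      // Save a small pair RDD to `path`, surfacing a clearer error when the
      // output directory already exists (the behavior the tests above assert).
      def saveOrFail(sc: SparkContext, path: String): Unit = {
        val rdd = sc.parallelize(Seq((1, "a"), (2, "b")), 1)
        try {
          rdd.saveAsTextFile(path)
        } catch {
          case e: FileAlreadyExistsException =>
            sys.error(s"Output directory $path already exists: ${e.getMessage}")
        }
      }
    }
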
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 6c1e325f6f348..a5bd72eb0a122 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark
import scala.concurrent.Await
import akka.actor._
+import akka.testkit.TestActorRef
import org.scalatest.FunSuite
import org.apache.spark.scheduler.MapStatus
@@ -51,14 +52,16 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("master start and stop") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTrackerMaster(conf)
- tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
+ tracker.trackerActor =
+ actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
tracker.stop()
}
test("master register and fetch") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTrackerMaster(conf)
- tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
+ tracker.trackerActor =
+ actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -77,7 +80,8 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("master register and unregister and fetch") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTrackerMaster(conf)
- tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
+ tracker.trackerActor =
+ actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -98,14 +102,18 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("remote fetch") {
val hostname = "localhost"
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf)
- System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf,
+ securityManager = new SecurityManager(conf))
+
+ // Will be cleared by LocalSparkContext
+ System.setProperty("spark.driver.port", boundPort.toString)
val masterTracker = new MapOutputTrackerMaster(conf)
masterTracker.trackerActor = actorSystem.actorOf(
- Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+ Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
- val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf)
+ val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf,
+ securityManager = new SecurityManager(conf))
val slaveTracker = new MapOutputTracker(conf)
val selection = slaveSystem.actorSelection(
s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
@@ -124,7 +132,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
masterTracker.incrementEpoch()
slaveTracker.updateEpoch(masterTracker.getEpoch)
assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
- Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
+ Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
masterTracker.incrementEpoch()
@@ -134,4 +142,44 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
// failure should be cached
intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
}
+
+ test("remote fetch below akka frame size") {
+ val newConf = new SparkConf
+ newConf.set("spark.akka.frameSize", "1")
+ newConf.set("spark.akka.askTimeout", "1") // Fail fast
+
+ val masterTracker = new MapOutputTrackerMaster(conf)
+ val actorSystem = ActorSystem("test")
+ val actorRef = TestActorRef[MapOutputTrackerMasterActor](
+ new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem)
+ val masterActor = actorRef.underlyingActor
+
+ // Frame size should be ~123B, and no exception should be thrown
+ masterTracker.registerShuffle(10, 1)
+ masterTracker.registerMapOutput(10, 0, new MapStatus(
+ BlockManagerId("88", "mph", 1000, 0), Array.fill[Byte](10)(0)))
+ masterActor.receive(GetMapOutputStatuses(10))
+ }
+
+ test("remote fetch exceeds akka frame size") {
+ val newConf = new SparkConf
+ newConf.set("spark.akka.frameSize", "1")
+ newConf.set("spark.akka.askTimeout", "1") // Fail fast
+
+ val masterTracker = new MapOutputTrackerMaster(conf)
+ val actorSystem = ActorSystem("test")
+ val actorRef = TestActorRef[MapOutputTrackerMasterActor](
+ new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem)
+ val masterActor = actorRef.underlyingActor
+
+ // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception.
+ // Note that the size is hand-selected here because map output statuses are compressed before
+ // being sent.
+ masterTracker.registerShuffle(20, 100)
+ (0 until 100).foreach { i =>
+ masterTracker.registerMapOutput(20, i, new MapStatus(
+ BlockManagerId("999", "mps", 1000, 0), Array.fill[Byte](4000000)(0)))
+ }
+ intercept[SparkException] { masterActor.receive(GetMapOutputStatuses(20)) }
+ }
}
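
Editor's note: the two new frame-size tests rest on simple arithmetic — spark.akka.frameSize is treated as megabytes, so a value of 1 caps a message at roughly 1024 * 1024 bytes; 100 compressed map statuses for 4,000,000-byte outputs blow past that, while a single small status (~123 B) does not. A hedged sketch of the size check; the helper is illustrative and the default of 10 MB is an assumption, since the real limit is enforced inside MapOutputTrackerMasterActor:

    import org.apache.spark.SparkConf

    object FrameSizeSketch {
      // Illustrative only: derive the frame limit in bytes from the MB-valued setting
      // (10 MB assumed as the default) and compare a serialized payload against it.
      def exceedsFrameSize(conf: SparkConf, serializedBytes: Long): Boolean = {
        val frameSizeBytes = conf.getInt("spark.akka.frameSize", 10).toLong * 1024 * 1024
        serializedBytes > frameSizeBytes
      }
    }
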
diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
index 4305686d3a6d5..996db70809320 100644
--- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -171,6 +171,8 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet
assert(abs(6.0/2 - rdd.mean) < 0.01)
assert(abs(1.0 - rdd.variance) < 0.01)
assert(abs(1.0 - rdd.stdev) < 0.01)
+ assert(stats.max === 4.0)
+ assert(stats.min === 2.0)
// Add other tests here for classes that should be able to handle empty partitions correctly
}
diff --git a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
index 3a0385a1b0bd9..0bac78d8a6bdf 100644
--- a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala
@@ -19,74 +19,152 @@ package org.apache.spark
import org.scalatest.FunSuite
+
+import org.apache.spark.rdd.{HadoopRDD, PipedRDD, HadoopPartition}
+import org.apache.hadoop.mapred.{JobConf, TextInputFormat, FileSplit}
+import org.apache.hadoop.fs.Path
+
+import scala.collection.Map
+import scala.sys.process._
+import scala.util.Try
+import org.apache.hadoop.io.{Text, LongWritable}
+
class PipedRDDSuite extends FunSuite with SharedSparkContext {
test("basic pipe") {
- val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ if (testCommandAvailable("cat")) {
+ val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
- val piped = nums.pipe(Seq("cat"))
+ val piped = nums.pipe(Seq("cat"))
- val c = piped.collect()
- assert(c.size === 4)
- assert(c(0) === "1")
- assert(c(1) === "2")
- assert(c(2) === "3")
- assert(c(3) === "4")
+ val c = piped.collect()
+ assert(c.size === 4)
+ assert(c(0) === "1")
+ assert(c(1) === "2")
+ assert(c(2) === "3")
+ assert(c(3) === "4")
+ } else {
+ assert(true)
+ }
}
test("advanced pipe") {
- val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
- val bl = sc.broadcast(List("0"))
-
- val piped = nums.pipe(Seq("cat"),
- Map[String, String](),
- (f: String => Unit) => {bl.value.map(f(_));f("\u0001")},
- (i:Int, f: String=> Unit) => f(i + "_"))
-
- val c = piped.collect()
-
- assert(c.size === 8)
- assert(c(0) === "0")
- assert(c(1) === "\u0001")
- assert(c(2) === "1_")
- assert(c(3) === "2_")
- assert(c(4) === "0")
- assert(c(5) === "\u0001")
- assert(c(6) === "3_")
- assert(c(7) === "4_")
-
- val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2)
- val d = nums1.groupBy(str=>str.split("\t")(0)).
- pipe(Seq("cat"),
- Map[String, String](),
- (f: String => Unit) => {bl.value.map(f(_));f("\u0001")},
- (i:Tuple2[String, Seq[String]], f: String=> Unit) => {for (e <- i._2){ f(e + "_")}}).collect()
- assert(d.size === 8)
- assert(d(0) === "0")
- assert(d(1) === "\u0001")
- assert(d(2) === "b\t2_")
- assert(d(3) === "b\t4_")
- assert(d(4) === "0")
- assert(d(5) === "\u0001")
- assert(d(6) === "a\t1_")
- assert(d(7) === "a\t3_")
+ if (testCommandAvailable("cat")) {
+ val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ val bl = sc.broadcast(List("0"))
+
+ val piped = nums.pipe(Seq("cat"),
+ Map[String, String](),
+ (f: String => Unit) => {
+ bl.value.map(f(_)); f("\u0001")
+ },
+ (i: Int, f: String => Unit) => f(i + "_"))
+
+ val c = piped.collect()
+
+ assert(c.size === 8)
+ assert(c(0) === "0")
+ assert(c(1) === "\u0001")
+ assert(c(2) === "1_")
+ assert(c(3) === "2_")
+ assert(c(4) === "0")
+ assert(c(5) === "\u0001")
+ assert(c(6) === "3_")
+ assert(c(7) === "4_")
+
+ val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2)
+ val d = nums1.groupBy(str => str.split("\t")(0)).
+ pipe(Seq("cat"),
+ Map[String, String](),
+ (f: String => Unit) => {
+ bl.value.map(f(_)); f("\u0001")
+ },
+ (i: Tuple2[String, Seq[String]], f: String => Unit) => {
+ for (e <- i._2) {
+ f(e + "_")
+ }
+ }).collect()
+ assert(d.size === 8)
+ assert(d(0) === "0")
+ assert(d(1) === "\u0001")
+ assert(d(2) === "b\t2_")
+ assert(d(3) === "b\t4_")
+ assert(d(4) === "0")
+ assert(d(5) === "\u0001")
+ assert(d(6) === "a\t1_")
+ assert(d(7) === "a\t3_")
+ } else {
+ assert(true)
+ }
}
test("pipe with env variable") {
- val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
- val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
- val c = piped.collect()
- assert(c.size === 2)
- assert(c(0) === "LALALA")
- assert(c(1) === "LALALA")
+ if (testCommandAvailable("printenv")) {
+ val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
+ val c = piped.collect()
+ assert(c.size === 2)
+ assert(c(0) === "LALALA")
+ assert(c(1) === "LALALA")
+ } else {
+ assert(true)
+ }
}
test("pipe with non-zero exit status") {
- val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
- val piped = nums.pipe(Seq("cat nonexistent_file", "2>", "/dev/null"))
- intercept[SparkException] {
- piped.collect()
+ if (testCommandAvailable("cat")) {
+ val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ val piped = nums.pipe(Seq("cat nonexistent_file", "2>", "/dev/null"))
+ intercept[SparkException] {
+ piped.collect()
+ }
+ } else {
+ assert(true)
}
}
+ test("test pipe exports map_input_file") {
+ testExportInputFile("map_input_file")
+ }
+
+ test("test pipe exports mapreduce_map_input_file") {
+ testExportInputFile("mapreduce_map_input_file")
+ }
+
+ def testCommandAvailable(command: String): Boolean = {
+ Try(Process(command) !!).isSuccess
+ }
+
+ def testExportInputFile(varName: String) {
+ if (testCommandAvailable("printenv")) {
+ val nums = new HadoopRDD(sc, new JobConf(), classOf[TextInputFormat], classOf[LongWritable],
+ classOf[Text], 2) {
+ override def getPartitions: Array[Partition] = Array(generateFakeHadoopPartition())
+
+ override val getDependencies = List[Dependency[_]]()
+
+ override def compute(theSplit: Partition, context: TaskContext) = {
+ new InterruptibleIterator[(LongWritable, Text)](context, Iterator((new LongWritable(1),
+ new Text("b"))))
+ }
+ }
+ val hadoopPart1 = generateFakeHadoopPartition()
+ val pipedRdd = new PipedRDD(nums, "printenv " + varName)
+ val tContext = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
+ taskMetrics = null)
+ val rddIter = pipedRdd.compute(hadoopPart1, tContext)
+ val arr = rddIter.toArray
+ assert(arr(0) == "/some/path")
+ } else {
+ // printenv isn't available so just pass the test
+ assert(true)
+ }
+ }
+
+ def generateFakeHadoopPartition(): HadoopPartition = {
+ val split = new FileSplit(new Path("/some/path"), 0, 1,
+ Array[String]("loc1", "loc2", "loc3", "loc4", "loc5"))
+ new HadoopPartition(sc.newRddId(), 1, split)
+ }
+
}
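
Editor's note: testCommandAvailable carries the whole portability story for this suite — run the command via scala.sys.process and treat a non-zero exit or a missing binary (both of which make `!!` throw) as "not available". A standalone sketch of the same check:

    import scala.sys.process._
    import scala.util.Try

    object CommandCheckSketch {
      // True if `command` can be launched and exits with status 0.
      // `!!` throws on a non-zero exit code, and a missing binary raises an
      // IOException, so a single Try covers both failure modes.
      def available(command: String): Boolean = Try(Process(command).!!).isSuccess

      def main(args: Array[String]): Unit = {
        println(s"cat available: ${available("cat")}")
        println(s"printenv available: ${available("printenv")}")
      }
    }
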
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index abea36f7c83df..be6508a40ea61 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -27,6 +27,9 @@ import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.util.MutablePair
class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
+
+ val conf = new SparkConf(loadDefaults = false)
+
test("groupByKey without compression") {
try {
System.setProperty("spark.shuffle.compress", "false")
@@ -54,7 +57,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
// If the Kryo serializer is not used correctly, the shuffle would fail because the
// default Java serializer cannot handle the non serializable class.
val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)](
- b, new HashPartitioner(NUM_BLOCKS)).setSerializer(classOf[KryoSerializer].getName)
+ b, new HashPartitioner(NUM_BLOCKS)).setSerializer(new KryoSerializer(conf))
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
assert(c.count === 10)
@@ -76,7 +79,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
// If the Kryo serializer is not used correctly, the shuffle would fail because the
// default Java serializer cannot handle the non serializable class.
val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)](
- b, new HashPartitioner(3)).setSerializer(classOf[KryoSerializer].getName)
+ b, new HashPartitioner(3)).setSerializer(new KryoSerializer(conf))
assert(c.count === 10)
}
@@ -92,7 +95,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
// NOTE: The default Java serializer doesn't create zero-sized blocks.
// So, use Kryo
val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10))
- .setSerializer(classOf[KryoSerializer].getName)
+ .setSerializer(new KryoSerializer(conf))
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
assert(c.count === 4)
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
index f28d5c7b133b3..3bb936790d506 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
@@ -95,6 +95,10 @@ class SparkContextSchedulerCreationSuite
}
}
+ test("yarn-cluster") {
+ testYarn("yarn-cluster", "org.apache.spark.scheduler.cluster.YarnClusterScheduler")
+ }
+
test("yarn-standalone") {
testYarn("yarn-standalone", "org.apache.spark.scheduler.cluster.YarnClusterScheduler")
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
index de866ed7ffed8..bae3b37e267d5 100644
--- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
@@ -20,9 +20,12 @@ package org.apache.spark.deploy
import java.io.File
import java.util.Date
-import net.liftweb.json.Diff
-import net.liftweb.json.{JsonAST, JsonParser}
-import net.liftweb.json.JsonAST.{JNothing, JValue}
+import org.json4s._
+
+import org.json4s.JValue
+import org.json4s.jackson.JsonMethods
+import com.fasterxml.jackson.core.JsonParseException
+
import org.scalatest.FunSuite
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse}
@@ -34,31 +37,31 @@ class JsonProtocolSuite extends FunSuite {
test("writeApplicationInfo") {
val output = JsonProtocol.writeApplicationInfo(createAppInfo())
assertValidJson(output)
- assertValidDataInJson(output, JsonParser.parse(JsonConstants.appInfoJsonStr))
+ assertValidDataInJson(output, JsonMethods.parse(JsonConstants.appInfoJsonStr))
}
test("writeWorkerInfo") {
val output = JsonProtocol.writeWorkerInfo(createWorkerInfo())
assertValidJson(output)
- assertValidDataInJson(output, JsonParser.parse(JsonConstants.workerInfoJsonStr))
+ assertValidDataInJson(output, JsonMethods.parse(JsonConstants.workerInfoJsonStr))
}
test("writeApplicationDescription") {
val output = JsonProtocol.writeApplicationDescription(createAppDesc())
assertValidJson(output)
- assertValidDataInJson(output, JsonParser.parse(JsonConstants.appDescJsonStr))
+ assertValidDataInJson(output, JsonMethods.parse(JsonConstants.appDescJsonStr))
}
test("writeExecutorRunner") {
val output = JsonProtocol.writeExecutorRunner(createExecutorRunner())
assertValidJson(output)
- assertValidDataInJson(output, JsonParser.parse(JsonConstants.executorRunnerJsonStr))
+ assertValidDataInJson(output, JsonMethods.parse(JsonConstants.executorRunnerJsonStr))
}
test("writeDriverInfo") {
val output = JsonProtocol.writeDriverInfo(createDriverInfo())
assertValidJson(output)
- assertValidDataInJson(output, JsonParser.parse(JsonConstants.driverInfoJsonStr))
+ assertValidDataInJson(output, JsonMethods.parse(JsonConstants.driverInfoJsonStr))
}
test("writeMasterState") {
@@ -71,7 +74,7 @@ class JsonProtocolSuite extends FunSuite {
activeDrivers, completedDrivers, RecoveryState.ALIVE)
val output = JsonProtocol.writeMasterState(stateResponse)
assertValidJson(output)
- assertValidDataInJson(output, JsonParser.parse(JsonConstants.masterStateJsonStr))
+ assertValidDataInJson(output, JsonMethods.parse(JsonConstants.masterStateJsonStr))
}
test("writeWorkerState") {
@@ -83,7 +86,7 @@ class JsonProtocolSuite extends FunSuite {
finishedExecutors, drivers, finishedDrivers, "masterUrl", 4, 1234, 4, 1234, "masterWebUiUrl")
val output = JsonProtocol.writeWorkerState(stateResponse)
assertValidJson(output)
- assertValidDataInJson(output, JsonParser.parse(JsonConstants.workerStateJsonStr))
+ assertValidDataInJson(output, JsonMethods.parse(JsonConstants.workerStateJsonStr))
}
def createAppDesc(): ApplicationDescription = {
@@ -125,9 +128,9 @@ class JsonProtocolSuite extends FunSuite {
def assertValidJson(json: JValue) {
try {
- JsonParser.parse(JsonAST.compactRender(json))
+ JsonMethods.parse(JsonMethods.compact(json))
} catch {
- case e: JsonParser.ParseException => fail("Invalid Json detected", e)
+ case e: JsonParseException => fail("Invalid Json detected", e)
}
}
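
Editor's note: the lift-json to json4s move in this suite reduces to a small mapping of entry points, all visible in the hunks above — JsonParser.parse becomes JsonMethods.parse, JsonAST.compactRender becomes JsonMethods.compact, and parse failures are reported as Jackson's JsonParseException. A round-trip sketch using only those calls:

    import org.json4s._
    import org.json4s.jackson.JsonMethods
    import com.fasterxml.jackson.core.JsonParseException

    object Json4sRoundTripSketch {
      // Render a JValue to its compact string form and parse it back,
      // mirroring the assertValidJson helper in the suite above.
      def roundTrip(json: JValue): JValue =
        try {
          JsonMethods.parse(JsonMethods.compact(json))
        } catch {
          case e: JsonParseException =>
            throw new IllegalStateException("Invalid JSON produced", e)
        }
    }
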
diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala
index c1e8b295dfe3b..96a5a1231813e 100644
--- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala
+++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala
@@ -18,21 +18,22 @@
package org.apache.spark.metrics
import org.scalatest.{BeforeAndAfter, FunSuite}
-
-import org.apache.spark.SparkConf
+import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.deploy.master.MasterSource
class MetricsSystemSuite extends FunSuite with BeforeAndAfter {
var filePath: String = _
var conf: SparkConf = null
+ var securityMgr: SecurityManager = null
before {
filePath = getClass.getClassLoader.getResource("test_metrics_system.properties").getFile()
conf = new SparkConf(false).set("spark.metrics.conf", filePath)
+ securityMgr = new SecurityManager(conf)
}
test("MetricsSystem with default config") {
- val metricsSystem = MetricsSystem.createMetricsSystem("default", conf)
+ val metricsSystem = MetricsSystem.createMetricsSystem("default", conf, securityMgr)
val sources = metricsSystem.sources
val sinks = metricsSystem.sinks
@@ -42,7 +43,7 @@ class MetricsSystemSuite extends FunSuite with BeforeAndAfter {
}
test("MetricsSystem with sources add") {
- val metricsSystem = MetricsSystem.createMetricsSystem("test", conf)
+ val metricsSystem = MetricsSystem.createMetricsSystem("test", conf, securityMgr)
val sources = metricsSystem.sources
val sinks = metricsSystem.sinks
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index e3e23775f011d..f9e994b13dfbc 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -347,6 +347,48 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
*/
pairs.saveAsNewAPIHadoopFile[ConfigTestFormat]("ignored")
}
+
+ test("lookup") {
+ val pairs = sc.parallelize(Array((1,2), (3,4), (5,6), (5,7)))
+
+ assert(pairs.partitioner === None)
+ assert(pairs.lookup(1) === Seq(2))
+ assert(pairs.lookup(5) === Seq(6,7))
+ assert(pairs.lookup(-1) === Seq())
+
+ }
+
+ test("lookup with partitioner") {
+ val pairs = sc.parallelize(Array((1,2), (3,4), (5,6), (5,7)))
+
+ val p = new Partitioner {
+ def numPartitions: Int = 2
+
+ def getPartition(key: Any): Int = Math.abs(key.hashCode() % 2)
+ }
+ val shuffled = pairs.partitionBy(p)
+
+ assert(shuffled.partitioner === Some(p))
+ assert(shuffled.lookup(1) === Seq(2))
+ assert(shuffled.lookup(5) === Seq(6,7))
+ assert(shuffled.lookup(-1) === Seq())
+ }
+
+ test("lookup with bad partitioner") {
+ val pairs = sc.parallelize(Array((1,2), (3,4), (5,6), (5,7)))
+
+ val p = new Partitioner {
+ def numPartitions: Int = 2
+
+ def getPartition(key: Any): Int = key.hashCode() % 2
+ }
+ val shuffled = pairs.partitionBy(p)
+
+ assert(shuffled.partitioner === Some(p))
+ assert(shuffled.lookup(1) === Seq(2))
+ intercept[IllegalArgumentException] {shuffled.lookup(-1)}
+ }
+
}
/*
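
Editor's note: the "bad partitioner" case is worth spelling out — key.hashCode() % 2 is -1 for keys with a negative hash code, lookup then runs a job on that partition index, and the out-of-range index surfaces as the IllegalArgumentException the test intercepts. A well-behaved partitioner keeps the index non-negative (as HashPartitioner does); a minimal sketch:

    import org.apache.spark.Partitioner

    // A partitioner that never returns a negative index, unlike the deliberately
    // broken one in the "lookup with bad partitioner" test above.
    class NonNegativeHashPartitioner(val numPartitions: Int) extends Partitioner {
      def getPartition(key: Any): Int = {
        val mod = key.hashCode() % numPartitions
        if (mod < 0) mod + numPartitions else mod // keep the index in [0, numPartitions)
      }
    }
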
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 60bcada55245b..d6b5fdc7984b4 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -47,6 +47,8 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4)))
assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4"))
assert(nums.keyBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4)))
+ assert(nums.max() === 4)
+ assert(nums.min() === 1)
val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _)))
assert(partitionSums.collect().toList === List(3, 7))
@@ -457,6 +459,7 @@ class RDDSuite extends FunSuite with SharedSparkContext {
test("takeSample") {
val data = sc.parallelize(1 to 100, 2)
+
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=false, 20, seed)
assert(sample.size === 20) // Got exactly 20 elements
@@ -488,6 +491,12 @@ class RDDSuite extends FunSuite with SharedSparkContext {
}
}
+ test("takeSample from an empty rdd") {
+ val emptySet = sc.parallelize(Seq.empty[Int], 2)
+ val sample = emptySet.takeSample(false, 20, 1)
+ assert(sample.length === 0)
+ }
+
test("randomSplit") {
val n = 600
val data = sc.parallelize(1 to n, 2)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
index 0b90c4e74c8a4..0a7cb69416a08 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -24,3 +24,19 @@ class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int
override def preferredLocations: Seq[TaskLocation] = prefLocs
}
+
+object FakeTask {
+ /**
+ * Utility method to create a TaskSet, potentially setting a particular sequence of preferred
+ * locations for each task (given as varargs) if this sequence is not empty.
+ */
+ def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
+ if (prefLocs.size != 0 && prefLocs.size != numTasks) {
+ throw new IllegalArgumentException("Wrong number of task locations")
+ }
+ val tasks = Array.tabulate[Task[_]](numTasks) { i =>
+ new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil)
+ }
+ new TaskSet(tasks, 0, 0, 0, null)
+ }
+}
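
Editor's note: the relocated factory is used like this by the scheduler suites; a short usage sketch against the test-only FakeTask object (names as in this patch):

    import org.apache.spark.scheduler.{FakeTask, TaskLocation, TaskSet}

    object FakeTaskUsageSketch {
      // No preferred locations: every task gets Nil.
      val plainSet: TaskSet = FakeTask.createTaskSet(3)

      // One Seq[TaskLocation] per task; the argument count must equal numTasks,
      // otherwise createTaskSet throws IllegalArgumentException.
      val localSet: TaskSet = FakeTask.createTaskSet(2,
        Seq(TaskLocation("host1", "exec1")),
        Seq(TaskLocation("host2")))
    }
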
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 368c5154ea3b9..7c4f2b4361892 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -129,7 +129,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
sm.localBlocksFetched should be > (0)
sm.remoteBlocksFetched should be (0)
sm.remoteBytesRead should be (0l)
- sm.remoteFetchTime should be (0l)
}
}
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
index ac07f60e284bb..c4e7a4bb7d385 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
@@ -93,10 +93,10 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA
// If this test hangs, it's probably because no resource offers were made after the task
// failed.
val scheduler: TaskSchedulerImpl = sc.taskScheduler match {
- case clusterScheduler: TaskSchedulerImpl =>
- clusterScheduler
+ case taskScheduler: TaskSchedulerImpl =>
+ taskScheduler
case _ =>
- assert(false, "Expect local cluster to use ClusterScheduler")
+ assert(false, "Expect local cluster to use TaskSchedulerImpl")
throw new ClassCastException
}
scheduler.taskResultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
similarity index 66%
rename from core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
rename to core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index 85e929925e3b5..6b0800af9c6d0 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -25,13 +25,20 @@ import org.scalatest.FunSuite
import org.apache.spark._
+class FakeSchedulerBackend extends SchedulerBackend {
+ def start() {}
+ def stop() {}
+ def reviveOffers() {}
+ def defaultParallelism() = 1
+}
+
class FakeTaskSetManager(
initPriority: Int,
initStageId: Int,
initNumTasks: Int,
- clusterScheduler: TaskSchedulerImpl,
+ taskScheduler: TaskSchedulerImpl,
taskSet: TaskSet)
- extends TaskSetManager(clusterScheduler, taskSet, 0) {
+ extends TaskSetManager(taskScheduler, taskSet, 0) {
parent = null
weight = 1
@@ -105,9 +112,10 @@ class FakeTaskSetManager(
}
}
-class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging {
+class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Logging {
- def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: TaskSchedulerImpl, taskSet: TaskSet): FakeTaskSetManager = {
+ def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: TaskSchedulerImpl,
+ taskSet: TaskSet): FakeTaskSetManager = {
new FakeTaskSetManager(priority, stage, numTasks, cs , taskSet)
}
@@ -133,20 +141,17 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
}
test("FIFO Scheduler Test") {
- sc = new SparkContext("local", "ClusterSchedulerSuite")
- val clusterScheduler = new TaskSchedulerImpl(sc)
- var tasks = ArrayBuffer[Task[_]]()
- val task = new FakeTask(0)
- tasks += task
- val taskSet = new TaskSet(tasks.toArray,0,0,0,null)
+ sc = new SparkContext("local", "TaskSchedulerImplSuite")
+ val taskScheduler = new TaskSchedulerImpl(sc)
+ val taskSet = FakeTask.createTaskSet(1)
val rootPool = new Pool("", SchedulingMode.FIFO, 0, 0)
val schedulableBuilder = new FIFOSchedulableBuilder(rootPool)
schedulableBuilder.buildPools()
- val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, clusterScheduler, taskSet)
- val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, clusterScheduler, taskSet)
- val taskSetManager2 = createDummyTaskSetManager(0, 2, 2, clusterScheduler, taskSet)
+ val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, taskScheduler, taskSet)
+ val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, taskScheduler, taskSet)
+ val taskSetManager2 = createDummyTaskSetManager(0, 2, 2, taskScheduler, taskSet)
schedulableBuilder.addTaskSetManager(taskSetManager0, null)
schedulableBuilder.addTaskSetManager(taskSetManager1, null)
schedulableBuilder.addTaskSetManager(taskSetManager2, null)
@@ -160,12 +165,9 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
}
test("Fair Scheduler Test") {
- sc = new SparkContext("local", "ClusterSchedulerSuite")
- val clusterScheduler = new TaskSchedulerImpl(sc)
- var tasks = ArrayBuffer[Task[_]]()
- val task = new FakeTask(0)
- tasks += task
- val taskSet = new TaskSet(tasks.toArray,0,0,0,null)
+ sc = new SparkContext("local", "TaskSchedulerImplSuite")
+ val taskScheduler = new TaskSchedulerImpl(sc)
+ val taskSet = FakeTask.createTaskSet(1)
val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
System.setProperty("spark.scheduler.allocation.file", xmlPath)
@@ -189,15 +191,15 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
val properties2 = new Properties()
properties2.setProperty("spark.scheduler.pool","2")
- val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, clusterScheduler, taskSet)
- val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, clusterScheduler, taskSet)
- val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, clusterScheduler, taskSet)
+ val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, taskScheduler, taskSet)
+ val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, taskScheduler, taskSet)
+ val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, taskScheduler, taskSet)
schedulableBuilder.addTaskSetManager(taskSetManager10, properties1)
schedulableBuilder.addTaskSetManager(taskSetManager11, properties1)
schedulableBuilder.addTaskSetManager(taskSetManager12, properties1)
- val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, clusterScheduler, taskSet)
- val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, clusterScheduler, taskSet)
+ val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, taskScheduler, taskSet)
+ val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, taskScheduler, taskSet)
schedulableBuilder.addTaskSetManager(taskSetManager23, properties2)
schedulableBuilder.addTaskSetManager(taskSetManager24, properties2)
@@ -217,12 +219,9 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
}
test("Nested Pool Test") {
- sc = new SparkContext("local", "ClusterSchedulerSuite")
- val clusterScheduler = new TaskSchedulerImpl(sc)
- var tasks = ArrayBuffer[Task[_]]()
- val task = new FakeTask(0)
- tasks += task
- val taskSet = new TaskSet(tasks.toArray,0,0,0,null)
+ sc = new SparkContext("local", "TaskSchedulerImplSuite")
+ val taskScheduler = new TaskSchedulerImpl(sc)
+ val taskSet = FakeTask.createTaskSet(1)
val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0)
val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1)
@@ -240,23 +239,23 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
pool1.addSchedulable(pool10)
pool1.addSchedulable(pool11)
- val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, clusterScheduler, taskSet)
- val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, clusterScheduler, taskSet)
+ val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, taskScheduler, taskSet)
+ val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, taskScheduler, taskSet)
pool00.addSchedulable(taskSetManager000)
pool00.addSchedulable(taskSetManager001)
- val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, clusterScheduler, taskSet)
- val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, clusterScheduler, taskSet)
+ val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, taskScheduler, taskSet)
+ val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, taskScheduler, taskSet)
pool01.addSchedulable(taskSetManager010)
pool01.addSchedulable(taskSetManager011)
- val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, clusterScheduler, taskSet)
- val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, clusterScheduler, taskSet)
+ val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, taskScheduler, taskSet)
+ val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, taskScheduler, taskSet)
pool10.addSchedulable(taskSetManager100)
pool10.addSchedulable(taskSetManager101)
- val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, clusterScheduler, taskSet)
- val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, clusterScheduler, taskSet)
+ val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, taskScheduler, taskSet)
+ val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, taskScheduler, taskSet)
pool11.addSchedulable(taskSetManager110)
pool11.addSchedulable(taskSetManager111)
@@ -265,4 +264,35 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
checkTaskSetId(rootPool, 6)
checkTaskSetId(rootPool, 2)
}
+
+ test("Scheduler does not always schedule tasks on the same workers") {
+ sc = new SparkContext("local", "TaskSchedulerImplSuite")
+ val taskScheduler = new TaskSchedulerImpl(sc)
+ taskScheduler.initialize(new FakeSchedulerBackend)
+ // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
+ var dagScheduler = new DAGScheduler(taskScheduler) {
+ override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
+ override def executorGained(execId: String, host: String) {}
+ }
+
+ val numFreeCores = 1
+ val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores),
+ new WorkerOffer("executor1", "host1", numFreeCores))
+ // Repeatedly try to schedule a 1-task job, and make sure that it doesn't always
+ // get scheduled on the same executor. While there is a chance this test will fail
+ // because the task randomly gets placed on the first executor all 1000 times, the
+ // probability of that happening is 2^-1000 (so sufficiently small to be considered
+ // negligible).
+ val numTrials = 1000
+ val selectedExecutorIds = 1.to(numTrials).map { _ =>
+ val taskSet = FakeTask.createTaskSet(1)
+ taskScheduler.submitTasks(taskSet)
+ val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
+ assert(1 === taskDescriptions.length)
+ taskDescriptions(0).executorId
+ }
+ var count = selectedExecutorIds.count(_ == workerOffers(0).executorId)
+ assert(count > 0)
+ assert(count < numTrials)
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 34a7d8cefeea2..33cc7588b919c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.util.FakeClock
-class FakeDAGScheduler(taskScheduler: FakeClusterScheduler) extends DAGScheduler(taskScheduler) {
+class FakeDAGScheduler(taskScheduler: FakeTaskScheduler) extends DAGScheduler(taskScheduler) {
override def taskStarted(task: Task[_], taskInfo: TaskInfo) {
taskScheduler.startedTasks += taskInfo.index
}
@@ -51,12 +51,12 @@ class FakeDAGScheduler(taskScheduler: FakeClusterScheduler) extends DAGScheduler
}
/**
- * A mock ClusterScheduler implementation that just remembers information about tasks started and
+ * A mock TaskSchedulerImpl implementation that just remembers information about tasks started and
* feedback received from the TaskSetManagers. Note that it's important to initialize this with
* a list of "live" executors and their hostnames for isExecutorAlive and hasExecutorsAliveOnHost
* to work, and these are required for locality in TaskSetManager.
*/
-class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */)
+class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */)
extends TaskSchedulerImpl(sc)
{
val startedTasks = new ArrayBuffer[Long]
@@ -87,8 +87,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
test("TaskSet with no preferences") {
sc = new SparkContext("local", "test")
- val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
- val taskSet = createTaskSet(1)
+ val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
+ val taskSet = FakeTask.createTaskSet(1)
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
// Offer a host with no CPUs
@@ -113,8 +113,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
test("multiple offers with no preferences") {
sc = new SparkContext("local", "test")
- val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
- val taskSet = createTaskSet(3)
+ val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
+ val taskSet = FakeTask.createTaskSet(3)
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
// First three offers should all find tasks
@@ -144,8 +144,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
test("basic delay scheduling") {
sc = new SparkContext("local", "test")
- val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
- val taskSet = createTaskSet(4,
+ val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
+ val taskSet = FakeTask.createTaskSet(4,
Seq(TaskLocation("host1", "exec1")),
Seq(TaskLocation("host2", "exec2")),
Seq(TaskLocation("host1"), TaskLocation("host2", "exec2")),
@@ -188,9 +188,9 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
test("delay scheduling with fallback") {
sc = new SparkContext("local", "test")
- val sched = new FakeClusterScheduler(sc,
+ val sched = new FakeTaskScheduler(sc,
("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3"))
- val taskSet = createTaskSet(5,
+ val taskSet = FakeTask.createTaskSet(5,
Seq(TaskLocation("host1")),
Seq(TaskLocation("host2")),
Seq(TaskLocation("host2")),
@@ -228,8 +228,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
test("delay scheduling with failed hosts") {
sc = new SparkContext("local", "test")
- val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
- val taskSet = createTaskSet(3,
+ val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
+ val taskSet = FakeTask.createTaskSet(3,
Seq(TaskLocation("host1")),
Seq(TaskLocation("host2")),
Seq(TaskLocation("host3"))
@@ -260,8 +260,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
test("task result lost") {
sc = new SparkContext("local", "test")
- val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
- val taskSet = createTaskSet(1)
+ val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
+ val taskSet = FakeTask.createTaskSet(1)
val clock = new FakeClock
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
@@ -277,8 +277,8 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
test("repeated failures lead to task set abortion") {
sc = new SparkContext("local", "test")
- val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
- val taskSet = createTaskSet(1)
+ val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
+ val taskSet = FakeTask.createTaskSet(1)
val clock = new FakeClock
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
@@ -298,21 +298,6 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
}
}
-
- /**
- * Utility method to create a TaskSet, potentially setting a particular sequence of preferred
- * locations for each task (given as varargs) if this sequence is not empty.
- */
- def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
- if (prefLocs.size != 0 && prefLocs.size != numTasks) {
- throw new IllegalArgumentException("Wrong number of task locations")
- }
- val tasks = Array.tabulate[Task[_]](numTasks) { i =>
- new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil)
- }
- new TaskSet(tasks, 0, 0, 0, null)
- }
-
def createTaskResult(id: Int): DirectTaskResult[Int] = {
val valueSer = SparkEnv.get.serializer.newInstance()
new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics)
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 9f011d9c8d132..1036b9f34e9dd 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -28,7 +28,7 @@ import org.scalatest.concurrent.Timeouts._
import org.scalatest.matchers.ShouldMatchers._
import org.scalatest.time.SpanSugar._
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.{SecurityManager, SparkConf, SparkContext}
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils}
@@ -39,6 +39,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
var actorSystem: ActorSystem = null
var master: BlockManagerMaster = null
var oldArch: String = null
+ conf.set("spark.authenticate", "false")
+ val securityMgr = new SecurityManager(conf)
// Reuse a serializer across tests to avoid creating a new thread-local buffer on each test
conf.set("spark.kryoserializer.buffer.mb", "1")
@@ -49,7 +51,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId)
before {
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, conf = conf)
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, conf = conf,
+ securityManager = securityMgr)
this.actorSystem = actorSystem
conf.set("spark.driver.port", boundPort.toString)
@@ -125,7 +128,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("master + 1 manager interaction") {
- store = new BlockManager("", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -155,8 +158,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("master + 2 managers interaction") {
- store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf)
- store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer(conf), 2000, conf)
+ store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, securityMgr)
+ store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer(conf), 2000, conf,
+ securityMgr)
val peers = master.getPeers(store.blockManagerId, 1)
assert(peers.size === 1, "master did not return the other manager as a peer")
@@ -171,7 +175,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("removing block") {
- store = new BlockManager("", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -219,7 +223,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("removing rdd") {
- store = new BlockManager("", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -253,7 +257,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("reregistration on heart beat") {
val heartBeat = PrivateMethod[Unit]('heartBeat)
- store = new BlockManager("", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr)
val a1 = new Array[Byte](400)
store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
@@ -269,7 +273,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("reregistration on block update") {
- store = new BlockManager("", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
@@ -288,7 +292,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("reregistration doesn't dead lock") {
val heartBeat = PrivateMethod[Unit]('heartBeat)
- store = new BlockManager("", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = List(new Array[Byte](400))
@@ -325,7 +329,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU storage") {
- store = new BlockManager("", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -344,7 +348,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU storage with serialization") {
- store = new BlockManager("", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -363,7 +367,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU for partitions of same RDD") {
- store = new BlockManager("", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -382,7 +386,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU for partitions of multiple RDDs") {
- store = new BlockManager("", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr)
store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
store.putSingle(rdd(0, 2), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
store.putSingle(rdd(1, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
@@ -405,7 +409,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("on-disk storage") {
- store = new BlockManager("", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -418,7 +422,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage") {
- store = new BlockManager("", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -433,7 +437,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with getLocalBytes") {
- store = new BlockManager("", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -448,7 +452,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with serialization") {
- store = new BlockManager("", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -463,7 +467,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with serialization and getLocalBytes") {
- store = new BlockManager("", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -478,7 +482,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("LRU with mixed storage levels") {
- store = new BlockManager("", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -503,7 +507,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU with streams") {
- store = new BlockManager("", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr)
val list1 = List(new Array[Byte](200), new Array[Byte](200))
val list2 = List(new Array[Byte](200), new Array[Byte](200))
val list3 = List(new Array[Byte](200), new Array[Byte](200))
@@ -527,7 +531,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("LRU with mixed storage levels and streams") {
- store = new BlockManager("", actorSystem, master, serializer, 1200, conf)
+ store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr)
val list1 = List(new Array[Byte](200), new Array[Byte](200))
val list2 = List(new Array[Byte](200), new Array[Byte](200))
val list3 = List(new Array[Byte](200), new Array[Byte](200))
@@ -573,7 +577,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("overly large block") {
- store = new BlockManager("", actorSystem, master, serializer, 500, conf)
+ store = new BlockManager("", actorSystem, master, serializer, 500, conf, securityMgr)
store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY)
assert(store.getSingle("a1") === None, "a1 was in store")
store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK)
@@ -584,7 +588,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("block compression") {
try {
conf.set("spark.shuffle.compress", "true")
- store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, securityMgr)
store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) <= 100,
"shuffle_0_0_0 was not compressed")
@@ -592,7 +596,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
store = null
conf.set("spark.shuffle.compress", "false")
- store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf, securityMgr)
store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 1000,
"shuffle_0_0_0 was compressed")
@@ -600,7 +604,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
store = null
conf.set("spark.broadcast.compress", "true")
- store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf, securityMgr)
store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 100,
"broadcast_0 was not compressed")
@@ -608,28 +612,28 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
store = null
conf.set("spark.broadcast.compress", "false")
- store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf, securityMgr)
store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 1000, "broadcast_0 was compressed")
store.stop()
store = null
conf.set("spark.rdd.compress", "true")
- store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf, securityMgr)
store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(rdd(0, 0)) <= 100, "rdd_0_0 was not compressed")
store.stop()
store = null
conf.set("spark.rdd.compress", "false")
- store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf, securityMgr)
store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(rdd(0, 0)) >= 1000, "rdd_0_0 was compressed")
store.stop()
store = null
// Check that any other block types are also kept uncompressed
- store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf)
+ store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf, securityMgr)
store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY)
assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed")
store.stop()
@@ -643,7 +647,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("block store put failure") {
// Use Java serializer so we can create an unserializable error.
- store = new BlockManager("", actorSystem, master, new JavaSerializer(conf), 1200, conf)
+ store = new BlockManager("", actorSystem, master, new JavaSerializer(conf), 1200, conf,
+ securityMgr)
// The put should fail since a1 is not serializable.
class UnserializableClass
@@ -657,4 +662,18 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
assert(store.getSingle("a1") == None, "a1 should not be in store")
}
}
+
+ test("SPARK-1194 regression: fix the same-RDD rule for cache replacement") {
+ store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr)
+ store.putSingle(rdd(0, 0), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(1, 0), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ // Access rdd_1_0 to ensure it's not least recently used.
+ assert(store.getSingle(rdd(1, 0)).isDefined, "rdd_1_0 was not in store")
+ // According to the same-RDD rule, rdd_1_0 should be replaced here.
+ store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ // rdd_1_0 should have been replaced, even though it's not the least recently used.
+ assert(store.memoryStore.contains(rdd(0, 0)), "rdd_0_0 was not in store")
+ assert(store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was not in store")
+ assert(!store.memoryStore.contains(rdd(1, 0)), "rdd_1_0 was in store")
+ }
}
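
The `securityMgr` value now threaded through these tests is constructed outside the hunks shown here. As a rough, non-authoritative sketch (assuming it is compiled inside a Spark package, since `SecurityManager` appears to be package-private, and using placeholder authentication settings), the setup could look like:

```scala
package org.apache.spark.storage

import org.apache.spark.{SecurityManager, SparkConf}

// Sketch only, not part of this patch: building the SecurityManager that the
// updated BlockManager constructor now takes as its last argument.
object SecurityManagerSetupSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false)
      .set("spark.authenticate", "true")               // placeholder setting for illustration
      .set("spark.authenticate.secret", "test-secret") // placeholder secret
    val securityMgr = new SecurityManager(conf)
    println("authentication enabled: " + conf.getBoolean("spark.authenticate", false))
    // A suite would then pass `securityMgr` to BlockManager alongside conf, as in the hunks above.
  }
}
```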
diff --git a/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala
new file mode 100644
index 0000000000000..bcf138b5ee6d0
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/FlatmapIteratorSuite.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.storage
+
+import org.scalatest.FunSuite
+import org.apache.spark.{SharedSparkContext, SparkConf, LocalSparkContext, SparkContext}
+
+
+class FlatmapIteratorSuite extends FunSuite with LocalSparkContext {
+ /* Tests the ability of Spark to deal with user-provided iterators from flatMap
+ * calls that may generate more data than available memory. For any
+ * memory-based persistence Spark will unroll the iterator into an ArrayBuffer
+ * for caching; however, when the user specifies DISK_ONLY persistence,
+ * the iterator will be fed directly to the serializer and written to disk.
+ *
+ * This also tests the ObjectOutputStream reset rate. When serializing using the
+ * Java serialization system, the serializer caches objects to prevent writing redundant
+ * data; however, that prevents those objects from being GC'd. Calling 'reset' flushes
+ * that info from the serializer, allowing old objects to be GC'd.
+ */
+ test("Flatmap Iterator to Disk") {
+ val sconf = new SparkConf().setMaster("local").setAppName("iterator_to_disk_test")
+ sc = new SparkContext(sconf)
+ val expand_size = 100
+ val data = sc.parallelize((1 to 5).toSeq).
+ flatMap( x => Stream.range(0, expand_size))
+ var persisted = data.persist(StorageLevel.DISK_ONLY)
+ assert(persisted.count()===500)
+ assert(persisted.filter(_==1).count()===5)
+ }
+
+ test("Flatmap Iterator to Memory") {
+ val sconf = new SparkConf().setMaster("local").setAppName("iterator_to_disk_test")
+ sc = new SparkContext(sconf)
+ val expand_size = 100
+ val data = sc.parallelize((1 to 5).toSeq).
+ flatMap(x => Stream.range(0, expand_size))
+ var persisted = data.persist(StorageLevel.MEMORY_ONLY)
+ assert(persisted.count()===500)
+ assert(persisted.filter(_==1).count()===5)
+ }
+
+ test("Serializer Reset") {
+ val sconf = new SparkConf().setMaster("local").setAppName("serializer_reset_test")
+ .set("spark.serializer.objectStreamReset", "10")
+ sc = new SparkContext(sconf)
+ val expand_size = 500
+ val data = sc.parallelize(Seq(1,2)).
+ flatMap(x => Stream.range(1, expand_size).
+ map(y => "%d: string test %d".format(y,x)))
+ var persisted = data.persist(StorageLevel.MEMORY_ONLY_SER)
+ assert(persisted.filter(_.startsWith("1:")).count()===2)
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
index 20ebb1897e6ba..30415814adbba 100644
--- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
@@ -24,6 +24,8 @@ import scala.util.{Failure, Success, Try}
import org.eclipse.jetty.server.Server
import org.scalatest.FunSuite
+import org.apache.spark.SparkConf
+
class UISuite extends FunSuite {
test("jetty port increases under contention") {
val startPort = 4040
@@ -34,15 +36,17 @@ class UISuite extends FunSuite {
case Failure(e) =>
// Either case server port is busy hence setup for test complete
}
- val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq())
- val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq())
+ val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq(),
+ new SparkConf)
+ val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq(),
+ new SparkConf)
// Allow some wiggle room in case ports on the machine are under contention
assert(boundPort1 > startPort && boundPort1 < startPort + 10)
assert(boundPort2 > boundPort1 && boundPort2 < boundPort1 + 10)
}
test("jetty binds to port 0 correctly") {
- val (jettyServer, boundPort) = JettyUtils.startJettyServer("0.0.0.0", 0, Seq())
+ val (jettyServer, boundPort) = JettyUtils.startJettyServer("0.0.0.0", 0, Seq(), new SparkConf)
assert(jettyServer.getState === "STARTED")
assert(boundPort != 0)
Try {new ServerSocket(boundPort)} match {
diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala
index c51d12bfe0bc6..757476efdb789 100644
--- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala
@@ -72,4 +72,8 @@ class XORShiftRandomSuite extends FunSuite with ShouldMatchers {
}
+ test ("XORShift with zero seed") {
+ val random = new XORShiftRandom(0L)
+ assert(random.nextInt() != 0)
+ }
}
diff --git a/dev/audit-release/README.md b/dev/audit-release/README.md
new file mode 100644
index 0000000000000..2437a98672177
--- /dev/null
+++ b/dev/audit-release/README.md
@@ -0,0 +1,11 @@
+# Test Application Builds
+This directory includes test applications which are built when auditing releases. You can
+run them locally by setting appropriate environment variables.
+
+```
+$ cd sbt_app_core
+$ SCALA_VERSION=2.10.3 \
+ SPARK_VERSION=1.0.0-SNAPSHOT \
+ SPARK_RELEASE_REPOSITORY=file:///home/patrick/.ivy2/local \
+ sbt run
+```
diff --git a/dev/audit-release/audit_release.py b/dev/audit-release/audit_release.py
index 4408658f5e33f..52c367d9b030d 100755
--- a/dev/audit-release/audit_release.py
+++ b/dev/audit-release/audit_release.py
@@ -31,10 +31,10 @@
import urllib2
## Fill in release details here:
-RELEASE_URL = "http://people.apache.org/~pwendell/spark-0.9.0-incubating-rc5/"
+RELEASE_URL = "http://people.apache.org/~pwendell/spark-1.0.0-rc1/"
RELEASE_KEY = "9E4FE3AF"
RELEASE_REPOSITORY = "https://repository.apache.org/content/repositories/orgapachespark-1006/"
-RELEASE_VERSION = "0.9.0-incubating"
+RELEASE_VERSION = "1.0.0"
SCALA_VERSION = "2.10.3"
SCALA_BINARY_VERSION = "2.10"
##
@@ -191,10 +191,6 @@ def ensure_path_not_present(x):
test("NOTICE" in base_files, "Tarball contains NOTICE file")
test("LICENSE" in base_files, "Tarball contains LICENSE file")
- os.chdir(os.path.join(WORK_DIR, dir_name))
- readme = "".join(open("README.md").readlines())
- disclaimer_part = "is an effort undergoing incubation"
- test(disclaimer_part in readme, "README file contains disclaimer")
os.chdir(WORK_DIR)
for artifact in artifacts:
diff --git a/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala
index d49de8b73a856..53fe43215e40e 100644
--- a/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala
+++ b/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala
@@ -17,6 +17,8 @@
package main.scala
+import scala.util.Try
+
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
@@ -31,6 +33,17 @@ object SimpleApp {
println("Failed to parse log files with Spark")
System.exit(-1)
}
- println("Test succeeded")
+
+ // Regression test for SPARK-1167: Remove metrics-ganglia from default build due to LGPL issue
+ val foundConsole = Try(Class.forName("org.apache.spark.metrics.sink.ConsoleSink")).isSuccess
+ val foundGanglia = Try(Class.forName("org.apache.spark.metrics.sink.GangliaSink")).isSuccess
+ if (!foundConsole) {
+ println("Console sink not loaded via spark-core")
+ System.exit(-1)
+ }
+ if (foundGanglia) {
+ println("Ganglia sink was loaded via spark-core")
+ System.exit(-1)
+ }
}
}
diff --git a/dev/audit-release/sbt_app_ganglia/build.sbt b/dev/audit-release/sbt_app_ganglia/build.sbt
new file mode 100644
index 0000000000000..55db675c722d1
--- /dev/null
+++ b/dev/audit-release/sbt_app_ganglia/build.sbt
@@ -0,0 +1,31 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+name := "Ganglia Test"
+
+version := "1.0"
+
+scalaVersion := System.getenv.get("SCALA_VERSION")
+
+libraryDependencies += "org.apache.spark" %% "spark-core" % System.getenv.get("SPARK_VERSION")
+
+libraryDependencies += "org.apache.spark" %% "spark-ganglia-lgpl" % System.getenv.get("SPARK_VERSION")
+
+resolvers ++= Seq(
+ "Spark Release Repository" at System.getenv.get("SPARK_RELEASE_REPOSITORY"),
+ "Akka Repository" at "http://repo.akka.io/releases/",
+ "Spray Repository" at "http://repo.spray.cc/")
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.scala b/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala
similarity index 53%
rename from core/src/main/scala/org/apache/spark/api/java/function/PairFunction.scala
rename to dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala
index d0ba0b6307ee9..0be8e64fbfabd 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.scala
+++ b/dev/audit-release/sbt_app_ganglia/src/main/scala/SparkApp.scala
@@ -15,19 +15,25 @@
* limitations under the License.
*/
-package org.apache.spark.api.java.function
+package main.scala
-import scala.reflect.ClassTag
-import org.apache.spark.api.java.JavaSparkContext
+import scala.util.Try
-/**
- * A function that returns key-value pairs (Tuple2), and can be used to construct PairRDDs.
- */
-// PairFunction does not extend Function because some UDF functions, like map,
-// are overloaded for both Function and PairFunction.
-abstract class PairFunction[T, K, V] extends WrappedFunction1[T, (K, V)] with Serializable {
-
- def keyType(): ClassTag[K] = JavaSparkContext.fakeClassTag
+import org.apache.spark.SparkContext
+import org.apache.spark.SparkContext._
- def valueType(): ClassTag[V] = JavaSparkContext.fakeClassTag
+object SimpleApp {
+ def main(args: Array[String]) {
+ // Regression test for SPARK-1167: Remove metrics-ganglia from default build due to LGPL issue
+ val foundConsole = Try(Class.forName("org.apache.spark.metrics.sink.ConsoleSink")).isSuccess
+ val foundGanglia = Try(Class.forName("org.apache.spark.metrics.sink.GangliaSink")).isSuccess
+ if (!foundConsole) {
+ println("Console sink not loaded via spark-core")
+ System.exit(-1)
+ }
+ if (!foundGanglia) {
+ println("Ganglia sink not loaded via spark-ganglia-lgpl")
+ System.exit(-1)
+ }
+ }
}
diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh
index 7cebace5069f8..995106f111443 100755
--- a/dev/create-release/create-release.sh
+++ b/dev/create-release/create-release.sh
@@ -39,8 +39,8 @@ GIT_TAG=v$RELEASE_VERSION
# Artifact publishing
-git clone https://git-wip-us.apache.org/repos/asf/incubator-spark.git -b $GIT_BRANCH
-cd incubator-spark
+git clone https://git-wip-us.apache.org/repos/asf/spark.git -b $GIT_BRANCH
+cd spark
export MAVEN_OPTS="-Xmx3g -XX:MaxPermSize=1g -XX:ReservedCodeCacheSize=1g"
mvn -Pyarn release:clean
@@ -49,21 +49,21 @@ mvn -DskipTests \
-Darguments="-DskipTests=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \
-Dusername=$GIT_USERNAME -Dpassword=$GIT_PASSWORD \
-Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \
- -Pyarn \
+ -Pyarn -Pspark-ganglia-lgpl \
-Dtag=$GIT_TAG -DautoVersionSubmodules=true \
--batch-mode release:prepare
mvn -DskipTests \
-Darguments="-DskipTests=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \
-Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \
- -Pyarn \
+ -Pyarn -Pspark-ganglia-lgpl \
release:perform
-rm -rf incubator-spark
+rm -rf spark
# Source and binary tarballs
-git clone https://git-wip-us.apache.org/repos/asf/incubator-spark.git
-cd incubator-spark
+git clone https://git-wip-us.apache.org/repos/asf/spark.git
+cd spark
git checkout --force $GIT_TAG
release_hash=`git rev-parse HEAD`
@@ -71,7 +71,7 @@ rm .gitignore
rm -rf .git
cd ..
-cp -r incubator-spark spark-$RELEASE_VERSION
+cp -r spark spark-$RELEASE_VERSION
tar cvzf spark-$RELEASE_VERSION.tgz spark-$RELEASE_VERSION
echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour --output spark-$RELEASE_VERSION.tgz.asc \
--detach-sig spark-$RELEASE_VERSION.tgz
@@ -85,7 +85,7 @@ make_binary_release() {
NAME=$1
MAVEN_FLAGS=$2
- cp -r incubator-spark spark-$RELEASE_VERSION-bin-$NAME
+ cp -r spark spark-$RELEASE_VERSION-bin-$NAME
cd spark-$RELEASE_VERSION-bin-$NAME
export MAVEN_OPTS="-Xmx3g -XX:MaxPermSize=1g -XX:ReservedCodeCacheSize=1g"
mvn $MAVEN_FLAGS -DskipTests clean package
@@ -118,9 +118,9 @@ scp spark* \
$USER_NAME@people.apache.org:/home/$USER_NAME/public_html/$rc_folder/
# Docs
-cd incubator-spark
+cd spark
cd docs
-jekyll build
+PRODUCTION=1 jekyll build
echo "Copying release documentation"
rc_docs_folder=${rc_folder}-docs
rsync -r _site/* $USER_NAME@people.apache.org:/home/$USER_NAME/public_html/$rc_docs_folder
diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py
index 93621c96daf2d..e8f78fc5f231a 100755
--- a/dev/merge_spark_pr.py
+++ b/dev/merge_spark_pr.py
@@ -38,7 +38,7 @@
# Remote name which points to Apache git
PUSH_REMOTE_NAME = os.environ.get("PUSH_REMOTE_NAME", "apache")
-GIT_API_BASE = "https://api.github.com/repos/apache/incubator-spark"
+GIT_API_BASE = "https://api.github.com/repos/apache/spark"
# Prefix added to temporary branches
BRANCH_PREFIX = "PR_TOOL"
diff --git a/dev/run-tests b/dev/run-tests
index d65a397b4c8c7..cf0b940c09a81 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -27,6 +27,16 @@ rm -rf ./work
# Fail fast
set -e
+if test -x "$JAVA_HOME/bin/java"; then
+ declare java_cmd="$JAVA_HOME/bin/java"
+else
+ declare java_cmd=java
+fi
+
+JAVA_VERSION=$($java_cmd -version 2>&1 | sed 's/java version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q')
+[ "$JAVA_VERSION" -ge 18 ] && echo "" || echo "[Warn] Java 8 tests will not run, because JDK version is < 1.8."
+
+
echo "========================================================================="
echo "Running Scala style checks"
echo "========================================================================="
diff --git a/docker/README.md b/docker/README.md
index bf59e77d111f9..40ba9c3065946 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -2,4 +2,6 @@ Spark docker files
===========
Drawn from Matt Massie's docker files (https://github.com/massie/dockerfiles),
-as well as some updates from Andre Schumacher (https://github.com/AndreSchumacher/docker).
\ No newline at end of file
+as well as some updates from Andre Schumacher (https://github.com/AndreSchumacher/docker).
+
+Tested with Docker version 0.8.1.
diff --git a/docker/spark-test/master/default_cmd b/docker/spark-test/master/default_cmd
index a5b1303c2ebdb..5a7da3446f6d2 100755
--- a/docker/spark-test/master/default_cmd
+++ b/docker/spark-test/master/default_cmd
@@ -19,4 +19,10 @@
IP=$(ip -o -4 addr list eth0 | perl -n -e 'if (m{inet\s([\d\.]+)\/\d+\s}xms) { print $1 }')
echo "CONTAINER_IP=$IP"
-/opt/spark/spark-class org.apache.spark.deploy.master.Master -i $IP
+export SPARK_LOCAL_IP=$IP
+export SPARK_PUBLIC_DNS=$IP
+
+# Avoid the default Docker behavior of mapping our IP address to an unreachable host name
+umount /etc/hosts
+
+/opt/spark/bin/spark-class org.apache.spark.deploy.master.Master -i $IP
diff --git a/docker/spark-test/worker/default_cmd b/docker/spark-test/worker/default_cmd
index ab6336f70c1c6..31b06cb0eb047 100755
--- a/docker/spark-test/worker/default_cmd
+++ b/docker/spark-test/worker/default_cmd
@@ -19,4 +19,10 @@
IP=$(ip -o -4 addr list eth0 | perl -n -e 'if (m{inet\s([\d\.]+)\/\d+\s}xms) { print $1 }')
echo "CONTAINER_IP=$IP"
-/opt/spark/spark-class org.apache.spark.deploy.worker.Worker $1
+export SPARK_LOCAL_IP=$IP
+export SPARK_PUBLIC_DNS=$IP
+
+# Avoid the default Docker behavior of mapping our IP address to an unreachable host name
+umount /etc/hosts
+
+/opt/spark/bin/spark-class org.apache.spark.deploy.worker.Worker $1
diff --git a/docs/README.md b/docs/README.md
index cc09d6e88f41e..0678fc5c86706 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -1,6 +1,6 @@
Welcome to the Spark documentation!
-This readme will walk you through navigating and building the Spark documentation, which is included here with the Spark source code. You can also find documentation specific to release versions of Spark at http://spark.incubator.apache.org/documentation.html.
+This readme will walk you through navigating and building the Spark documentation, which is included here with the Spark source code. You can also find documentation specific to release versions of Spark at http://spark.apache.org/documentation.html.
Read on to learn more about viewing documentation in plain text (i.e., markdown) or building the documentation yourself. Why build it yourself? So that you have the docs that correspond to whichever version of Spark you currently have checked out of revision control.
@@ -10,9 +10,22 @@ We include the Spark documentation as part of the source (as opposed to using a
In this directory you will find textfiles formatted using Markdown, with an ".md" suffix. You can read those text files directly if you want. Start with index.md.
-To make things quite a bit prettier and make the links easier to follow, generate the html version of the documentation based on the src directory by running `jekyll build` in the docs directory. Use the command `SKIP_SCALADOC=1 jekyll build` to skip building and copying over the scaladoc which can be timely. To use the `jekyll` command, you will need to have Jekyll installed, the easiest way to do this is via a Ruby Gem, see the [jekyll installation instructions](http://jekyllrb.com/docs/installation). This will create a directory called _site containing index.html as well as the rest of the compiled files. Read more about Jekyll at https://github.com/mojombo/jekyll/wiki.
-
-In addition to generating the site as html from the markdown files, jekyll can serve up the site via a webserver. To build and run a local webserver use the command `jekyll serve` (or the faster variant `SKIP_SCALADOC=1 jekyll serve`), which runs the webserver on port 4000, then visit the site at http://localhost:4000.
+The markdown code can be compiled to HTML using the
+[Jekyll tool](http://jekyllrb.com).
+To use the `jekyll` command, you will need to have Jekyll installed.
+The easiest way to do this is via a Ruby Gem, see the
+[jekyll installation instructions](http://jekyllrb.com/docs/installation).
+Compiling the site with Jekyll will create a directory called
+_site containing index.html as well as the rest of the compiled files.
+
+You can modify the default Jekyll build as follows:
+
+ # Skip generating API docs (which takes a while)
+ $ SKIP_SCALADOC=1 jekyll build
+ # Serve content locally on port 4000
+ $ jekyll serve --watch
+ # Build the site with extra features used on the live page
+ $ PRODUCTION=1 jekyll build
## Pygments
diff --git a/docs/_config.yml b/docs/_config.yml
index 9e5a95fe53af6..aa5a5adbc1743 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -3,10 +3,10 @@ markdown: kramdown
# These allow the documentation to be updated with new releases
# of Spark, Scala, and Mesos.
-SPARK_VERSION: 1.0.0-incubating-SNAPSHOT
+SPARK_VERSION: 1.0.0-SNAPSHOT
SPARK_VERSION_SHORT: 1.0.0
SCALA_BINARY_VERSION: "2.10"
SCALA_VERSION: "2.10.3"
MESOS_VERSION: 0.13.0
SPARK_ISSUE_TRACKER_URL: https://spark-project.atlassian.net
-SPARK_GITHUB_URL: https://github.com/apache/incubator-spark
+SPARK_GITHUB_URL: https://github.com/apache/spark
diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html
index 7114e1f5dd5b9..49fd78ca98655 100755
--- a/docs/_layouts/global.html
+++ b/docs/_layouts/global.html
@@ -24,9 +24,9 @@
+ {% production %}
+ {% endproduction %}
@@ -159,16 +159,6 @@
Heading
-->
-
-
diff --git a/docs/_plugins/production_tag.rb b/docs/_plugins/production_tag.rb
new file mode 100644
index 0000000000000..9f870cf2137af
--- /dev/null
+++ b/docs/_plugins/production_tag.rb
@@ -0,0 +1,14 @@
+module Jekyll
+ class ProductionTag < Liquid::Block
+
+ def initialize(tag_name, markup, tokens)
+ super
+ end
+
+ def render(context)
+ if ENV['PRODUCTION'] then super else "" end
+ end
+ end
+end
+
+Liquid::Template.register_tag('production', Jekyll::ProductionTag)
diff --git a/docs/bagel-programming-guide.md b/docs/bagel-programming-guide.md
index b070d8e73a38b..da6d0c9dcd97b 100644
--- a/docs/bagel-programming-guide.md
+++ b/docs/bagel-programming-guide.md
@@ -108,7 +108,7 @@ _Example_
## Operations
-Here are the actions and types in the Bagel API. See [Bagel.scala](https://github.com/apache/incubator-spark/blob/master/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala) for details.
+Here are the actions and types in the Bagel API. See [Bagel.scala](https://github.com/apache/spark/blob/master/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala) for details.
### Actions
diff --git a/docs/building-with-maven.md b/docs/building-with-maven.md
index ded12926885b9..730a6e7932564 100644
--- a/docs/building-with-maven.md
+++ b/docs/building-with-maven.md
@@ -25,6 +25,8 @@ If you don't run this, you may see errors like the following:
You can fix this by setting the `MAVEN_OPTS` variable as discussed before.
+*Note: For Java 1.8 and above this step is not required.*
+
## Specifying the Hadoop version ##
Because HDFS is not protocol-compatible across versions, if you want to read from HDFS, you'll need to build Spark against the specific HDFS version in your environment. You can do this through the "hadoop.version" property. If unset, Spark will build against Hadoop 1.0.4 by default.
@@ -54,7 +56,7 @@ Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.o
The ScalaTest plugin also supports running only a specific test suite as follows:
- $ mvn -Dhadoop.version=... -Dsuites=spark.repl.ReplSuite test
+ $ mvn -Dhadoop.version=... -Dsuites=org.apache.spark.repl.ReplSuite test
## Continuous Compilation ##
@@ -76,3 +78,19 @@ The maven build includes support for building a Debian package containing the as
$ mvn -Pdeb -DskipTests clean package
The debian package can then be found under assembly/target. We added the short commit hash to the file name so that we can distinguish individual packages built for SNAPSHOT versions.
+
+## Running Java 8 test suites
+
+To run only the Java 8 tests and nothing else:
+
+ $ mvn install -DskipTests -Pjava8-tests
+
+Java 8 tests are run when the -Pjava8-tests profile is enabled; they will run even if -DskipTests is set.
+For these tests to run your system must have a JDK 8 installation.
+If you have JDK 8 installed but it is not the system default, you can set JAVA_HOME to point to JDK 8 before running the tests.
+
+## Packaging without Hadoop dependencies for deployment on YARN ##
+
+The assembly jar produced by "mvn package" will, by default, include all of Spark's dependencies, including Hadoop and some of its ecosystem projects. On YARN deployments, this causes multiple versions of these to appear on executor classpaths: the version packaged in the Spark assembly and the version on each node, included with yarn.application.classpath. The "hadoop-provided" profile builds the assembly without including Hadoop-ecosystem projects, like ZooKeeper and Hadoop itself.
+
+
diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md
index e16703292cc22..a555a7b5023e3 100644
--- a/docs/cluster-overview.md
+++ b/docs/cluster-overview.md
@@ -13,7 +13,7 @@ object in your main program (called the _driver program_).
Specifically, to run on a cluster, the SparkContext can connect to several types of _cluster managers_
(either Spark's own standalone cluster manager or Mesos/YARN), which allocate resources across
applications. Once connected, Spark acquires *executors* on nodes in the cluster, which are
-worker processes that run computations and store data for your application.
+processes that run computations and store data for your application.
Next, it sends your application code (defined by JAR or Python files passed to SparkContext) to
the executors. Finally, SparkContext sends *tasks* for the executors to run.
diff --git a/docs/configuration.md b/docs/configuration.md
index 8e4c48c81f8be..a006224d5080c 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -147,6 +147,34 @@ Apart from these, the following properties are also available, and may be useful
How many stages the Spark UI remembers before garbage collecting.
+<tr>
+  <td>spark.ui.filters</td>
+  <td>None</td>
+  <td>
+    Comma separated list of filter class names to apply to the Spark web UI. The filter should be a
+    standard javax servlet Filter. Parameters to each filter can also be specified by setting a
+    java system property of spark.<class name of filter>.params='param1=value1,param2=value2'
+    (e.g. -Dspark.ui.filters=com.test.filter1 -Dspark.com.test.filter1.params='param1=foo,param2=testing')
+  </td>
+</tr>
+<tr>
+  <td>spark.ui.acls.enable</td>
+  <td>false</td>
+  <td>
+    Whether Spark web UI ACLs should be enabled. If enabled, this checks to see if the user has
+    access permissions to view the web UI. See spark.ui.view.acls for more details.
+    Also note this requires the user to be known; if the user comes across as null no checks
+    are done. Filters can be used to authenticate and set the user.
+  </td>
+</tr>
+<tr>
+  <td>spark.ui.view.acls</td>
+  <td>Empty</td>
+  <td>
+    Comma separated list of users that have view access to the Spark web UI. By default only the
+    user that started the Spark job has view access.
+  </td>
+</tr>
spark.shuffle.compress
true
@@ -201,6 +229,13 @@ Apart from these, the following properties are also available, and may be useful
multi-user services.
+<tr>
+  <td>spark.scheduler.revive.interval</td>
+  <td>1000</td>
+  <td>
+    The interval length (in milliseconds) for the scheduler to revive the worker resource offers to run tasks.
+  </td>
+</tr>
spark.reducer.maxMbInFlight
48
@@ -237,6 +272,17 @@ Apart from these, the following properties are also available, and may be useful
exceeded" exception inside Kryo. Note that there will be one buffer per core on each worker.
+<tr>
+  <td>spark.serializer.objectStreamReset</td>
+  <td>10000</td>
+  <td>
+    When serializing using org.apache.spark.serializer.JavaSerializer, the serializer caches
+    objects to prevent writing redundant data, however that stops garbage collection of those
+    objects. By calling 'reset' you flush that info from the serializer, and allow old
+    objects to be collected. To turn off this periodic reset set it to a value of <= 0.
+    By default it will reset the serializer every 10,000 objects.
+  </td>
+</tr>
spark.broadcast.factory
org.apache.spark.broadcast. HttpBroadcastFactory
@@ -469,7 +515,7 @@ Apart from these, the following properties are also available, and may be useful
the whole cluster by default. Note: this setting needs to be configured in the standalone cluster master, not in individual
applications; you can set it through SPARK_JAVA_OPTS in spark-env.sh.
-
+
spark.files.overwrite
@@ -478,6 +524,38 @@ Apart from these, the following properties are also available, and may be useful
Whether to overwrite files added through SparkContext.addFile() when the target file exists and its contents do not match those of the source.
+<tr>
+  <td>spark.files.fetchTimeout</td>
+  <td>false</td>
+  <td>
+    Communication timeout to use when fetching files added through SparkContext.addFile() from
+    the driver.
+  </td>
+</tr>
+<tr>
+  <td>spark.authenticate</td>
+  <td>false</td>
+  <td>
+    Whether Spark authenticates its internal connections. See spark.authenticate.secret if not
+    running on YARN.
+  </td>
+</tr>
+<tr>
+  <td>spark.authenticate.secret</td>
+  <td>None</td>
+  <td>
+    Set the secret key used for Spark to authenticate between components. This needs to be set if
+    not running on YARN and authentication is enabled.
+  </td>
+</tr>
+<tr>
+  <td>spark.core.connection.auth.wait.timeout</td>
+  <td>30</td>
+  <td>
+    Number of seconds for the connection to wait for authentication to occur before timing
+    out and giving up.
+  </td>
+</tr>
## Viewing Spark Properties
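
To illustrate the security-related properties added to this page, here is a hedged sketch of setting them programmatically on a `SparkConf`; the master, application name, secret, and user list are placeholders, not values taken from this patch:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Sketch: enabling authentication and web UI ACLs via the properties documented above.
object SecureConfSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local")                            // placeholder master
      .setAppName("secure_conf_sketch")              // placeholder app name
      .set("spark.authenticate", "true")
      .set("spark.authenticate.secret", "my-secret") // required when not running on YARN
      .set("spark.ui.acls.enable", "true")
      .set("spark.ui.view.acls", "alice,bob")        // placeholder users allowed to view the UI
    val sc = new SparkContext(conf)
    // ... run jobs as usual ...
    sc.stop()
  }
}
```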
diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md
index 3dfed7bea9ea8..1238e3e0a4e7d 100644
--- a/docs/graphx-programming-guide.md
+++ b/docs/graphx-programming-guide.md
@@ -135,7 +135,7 @@ Like RDDs, property graphs are immutable, distributed, and fault-tolerant. Chan
structure of the graph are accomplished by producing a new graph with the desired changes. Note
that substantial parts of the original graph (i.e., unaffected structure, attributes, and indices)
are reused in the new graph reducing the cost of this inherently functional data-structure. The
-graph is partitioned across the workers using a range of vertex-partitioning heuristics. As with
+graph is partitioned across the executors using a range of vertex-partitioning heuristics. As with
RDDs, each partition of the graph can be recreated on a different machine in the event of a failure.
Logically the property graph corresponds to a pair of typed collections (RDDs) encoding the
diff --git a/docs/index.md b/docs/index.md
index aa9c8666e7d75..23311101e1712 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -9,7 +9,7 @@ It also supports a rich set of higher-level tools including [Shark](http://shark
# Downloading
-Get Spark by visiting the [downloads page](http://spark.incubator.apache.org/downloads.html) of the Apache Spark site. This documentation is for Spark version {{site.SPARK_VERSION}}.
+Get Spark by visiting the [downloads page](http://spark.apache.org/downloads.html) of the Apache Spark site. This documentation is for Spark version {{site.SPARK_VERSION}}.
Spark runs on both Windows and UNIX-like systems (e.g. Linux, Mac OS). All you need to run it is to have `java` installed on your system `PATH`, or the `JAVA_HOME` environment variable pointing to a Java installation.
@@ -23,10 +23,12 @@ For its Scala API, Spark {{site.SPARK_VERSION}} depends on Scala {{site.SCALA_BI
# Running the Examples and Shell
-Spark comes with several sample programs in the `examples` directory.
-To run one of the samples, use `./bin/run-example <class> <params>` in the top-level Spark directory
+Spark comes with several sample programs. Scala and Java examples are in the `examples` directory, and Python examples are in `python/examples`.
+To run one of the Java or Scala sample programs, use `./bin/run-example <class> <params>` in the top-level Spark directory
(the `bin/run-example` script sets up the appropriate paths and launches that program).
For example, try `./bin/run-example org.apache.spark.examples.SparkPi local`.
+To run a Python sample program, use `./bin/pyspark <sample-program> <params>`. For example, try `./bin/pyspark ./python/examples/pi.py local`.
+
Each example prints usage help when run with no parameters.
Note that all of the sample programs take a `<master>` parameter specifying the cluster URL
@@ -96,13 +98,14 @@ For this version of Spark (0.8.1) Hadoop 2.2.x (or newer) users will have to bui
* [Amazon EC2](ec2-scripts.html): scripts that let you launch a cluster on EC2 in about 5 minutes
* [Standalone Deploy Mode](spark-standalone.html): launch a standalone cluster quickly without a third-party cluster manager
* [Mesos](running-on-mesos.html): deploy a private cluster using
- [Apache Mesos](http://incubator.apache.org/mesos)
+ [Apache Mesos](http://mesos.apache.org)
* [YARN](running-on-yarn.html): deploy Spark on top of Hadoop NextGen (YARN)
**Other documents:**
* [Configuration](configuration.html): customize Spark via its configuration system
* [Tuning Guide](tuning.html): best practices to optimize performance and memory use
+* [Security](security.html): Spark security support
* [Hardware Provisioning](hardware-provisioning.html): recommendations for cluster hardware
* [Job Scheduling](job-scheduling.html): scheduling resources across and within Spark applications
* [Building Spark with Maven](building-with-maven.html): build Spark using the Maven system
@@ -110,20 +113,20 @@ For this version of Spark (0.8.1) Hadoop 2.2.x (or newer) users will have to bui
**External resources:**
-* [Spark Homepage](http://spark.incubator.apache.org)
+* [Spark Homepage](http://spark.apache.org)
* [Shark](http://shark.cs.berkeley.edu): Apache Hive over Spark
-* [Mailing Lists](http://spark.incubator.apache.org/mailing-lists.html): ask questions about Spark here
+* [Mailing Lists](http://spark.apache.org/mailing-lists.html): ask questions about Spark here
* [AMP Camps](http://ampcamp.berkeley.edu/): a series of training camps at UC Berkeley that featured talks and
exercises about Spark, Shark, Mesos, and more. [Videos](http://ampcamp.berkeley.edu/agenda-2012),
[slides](http://ampcamp.berkeley.edu/agenda-2012) and [exercises](http://ampcamp.berkeley.edu/exercises-2012) are
available online for free.
-* [Code Examples](http://spark.incubator.apache.org/examples.html): more are also available in the [examples subfolder](https://github.com/apache/incubator-spark/tree/master/examples/src/main/scala/) of Spark
+* [Code Examples](http://spark.apache.org/examples.html): more are also available in the [examples subfolder](https://github.com/apache/spark/tree/master/examples/src/main/scala/) of Spark
* [Paper Describing Spark](http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf)
* [Paper Describing Spark Streaming](http://www.eecs.berkeley.edu/Pubs/TechRpts/2012/EECS-2012-259.pdf)
# Community
-To get help using Spark or keep up with Spark development, sign up for the [user mailing list](http://spark.incubator.apache.org/mailing-lists.html).
+To get help using Spark or keep up with Spark development, sign up for the [user mailing list](http://spark.apache.org/mailing-lists.html).
If you're in the San Francisco Bay Area, there's a regular [Spark meetup](http://www.meetup.com/spark-users/) every few weeks. Come by to meet the developers and other users.
diff --git a/docs/java-programming-guide.md b/docs/java-programming-guide.md
index 07732fa1229f3..6632360f6e3ca 100644
--- a/docs/java-programming-guide.md
+++ b/docs/java-programming-guide.md
@@ -21,15 +21,21 @@ operations (e.g. map) and handling RDDs of different types, as discussed next.
There are a few key differences between the Java and Scala APIs:
-* Java does not support anonymous or first-class functions, so functions must
- be implemented by extending the
+* Java does not support anonymous or first-class functions, so functions are passed
+ using anonymous classes that implement the
[`org.apache.spark.api.java.function.Function`](api/core/index.html#org.apache.spark.api.java.function.Function),
[`Function2`](api/core/index.html#org.apache.spark.api.java.function.Function2), etc.
- classes.
+ interfaces.
* To maintain type safety, the Java API defines specialized Function and RDD
classes for key-value pairs and doubles. For example,
[`JavaPairRDD`](api/core/index.html#org.apache.spark.api.java.JavaPairRDD)
stores key-value pairs.
+* Some methods are defined based on the return type of the passed anonymous function
+ (a.k.a. lambda expression); for example, mapToPair(...) and flatMapToPair return
+ [`JavaPairRDD`](api/core/index.html#org.apache.spark.api.java.JavaPairRDD),
+ while mapToDouble and flatMapToDouble return
+ [`JavaDoubleRDD`](api/core/index.html#org.apache.spark.api.java.JavaDoubleRDD).
* RDD methods like `collect()` and `countByKey()` return Java collections types,
such as `java.util.List` and `java.util.Map`.
* Key-value pairs, which are simply written as `(key, value)` in Scala, are represented
@@ -53,10 +59,10 @@ each specialized RDD class, so filtering a `PairRDD` returns a new `PairRDD`,
etc (this achieves the "same-result-type" principle used by the [Scala collections
framework](http://docs.scala-lang.org/overviews/core/architecture-of-scala-collections.html)).
-## Function Classes
+## Function Interfaces
-The following table lists the function classes used by the Java API. Each
-class has a single abstract method, `call()`, that must be implemented.
+The following table lists the function interfaces used by the Java API. Each
+interface has a single abstract method, `call()`, that must be implemented.
Class
Function Type
@@ -78,7 +84,6 @@ RDD [storage level](scala-programming-guide.html#rdd-persistence) constants, suc
declared in the [org.apache.spark.api.java.StorageLevels](api/core/index.html#org.apache.spark.api.java.StorageLevels) class. To
define your own storage level, you can use StorageLevels.create(...).
-
# Other Features
The Java API supports other Spark features, including
@@ -86,6 +91,21 @@ The Java API supports other Spark features, including
[broadcast variables](scala-programming-guide.html#broadcast-variables), and
[caching](scala-programming-guide.html#rdd-persistence).
+# Upgrading From Pre-1.0 Versions of Spark
+
+In version 1.0 of Spark the Java API was refactored to better support Java 8
+lambda expressions. Users upgrading from older versions of Spark should note
+the following changes:
+
+* All `org.apache.spark.api.java.function.*` have been changed from abstract
+ classes to interfaces. This means that concrete implementations of these
+ `Function` classes will need to use `implements` rather than `extends`.
+* Certain transformation functions now have multiple versions depending
+ on the return type. In Spark core, the map functions (map, flatMap,
+ mapPartitions) have type-specific versions, e.g.
+ [`mapToPair`](api/core/index.html#org.apache.spark.api.java.JavaRDD@mapToPair[K2,V2](f:org.apache.spark.api.java.function.PairFunction[T,K2,V2]):org.apache.spark.api.java.JavaPairRDD[K2,V2])
+ and [`mapToDouble`](api/core/index.html#org.apache.spark.api.java.JavaRDD@mapToDouble[R](f:org.apache.spark.api.java.function.DoubleFunction[T]):org.apache.spark.api.java.JavaDoubleRDD).
+ Spark Streaming also uses the same approach, e.g. [`transformToPair`](api/streaming/index.html#org.apache.spark.streaming.api.java.JavaDStream@transformToPair[K2,V2](transformFunc:org.apache.spark.api.java.function.Function[R,org.apache.spark.api.java.JavaPairRDD[K2,V2]]):org.apache.spark.streaming.api.java.JavaPairDStream[K2,V2]).
# Example
@@ -127,11 +147,20 @@ class Split extends FlatMapFunction {
JavaRDD<String> words = lines.flatMap(new Split());
{% endhighlight %}
+Java 8+ users can also write the above `FlatMapFunction` in a more concise way using
+a lambda expression:
+
+{% highlight java %}
+JavaRDD<String> words = lines.flatMap(s -> Arrays.asList(s.split(" ")));
+{% endhighlight %}
+
+This lambda syntax can be applied to all anonymous classes in Java 8.
+
Continuing with the word count example, we map each word to a `(word, 1)` pair:
{% highlight java %}
import scala.Tuple2;
-JavaPairRDD<String, Integer> ones = words.map(
+JavaPairRDD<String, Integer> ones = words.mapToPair(
new PairFunction<String, String, Integer>() {
public Tuple2<String, Integer> call(String s) {
return new Tuple2<String, Integer>(s, 1);
@@ -140,7 +169,7 @@ JavaPairRDD ones = words.map(
);
{% endhighlight %}
-Note that `map` was passed a `PairFunction<String, String, Integer>` and
+Note that `mapToPair` was passed a `PairFunction<String, String, Integer>` and
returned a `JavaPairRDD<String, Integer>`.
To finish the word count program, we will use `reduceByKey` to count the
@@ -164,7 +193,7 @@ possible to chain the RDD transformations, so the word count example could also
be written as:
{% highlight java %}
-JavaPairRDD<String, Integer> counts = lines.flatMap(
+JavaPairRDD<String, Integer> counts = lines.flatMapToPair(
...
).map(
...
@@ -180,16 +209,17 @@ just a matter of style.
We currently provide documentation for the Java API as Scaladoc, in the
[`org.apache.spark.api.java` package](api/core/index.html#org.apache.spark.api.java.package), because
-some of the classes are implemented in Scala. The main downside is that the types and function
+some of the classes are implemented in Scala. It is important to note that the types and function
definitions show Scala syntax (for example, `def reduce(func: Function2[T, T]): T` instead of
-`T reduce(Function2<T, T> func)`).
-We hope to generate documentation with Java-style syntax in the future.
+`T reduce(Function2<T, T> func)`). In addition, the Scala `trait` modifier is used for Java
+interface classes. We hope to generate documentation with Java-style syntax in the future to
+avoid these quirks.
# Where to Go from Here
Spark includes several sample programs using the Java API in
-[`examples/src/main/java`](https://github.com/apache/incubator-spark/tree/master/examples/src/main/java/org/apache/spark/examples). You can run them by passing the class name to the
+[`examples/src/main/java`](https://github.com/apache/spark/tree/master/examples/src/main/java/org/apache/spark/examples). You can run them by passing the class name to the
`bin/run-example` script included in Spark; for example:
./bin/run-example org.apache.spark.examples.JavaWordCount
diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md
index df2faa5e41b18..94604f301dd46 100644
--- a/docs/job-scheduling.md
+++ b/docs/job-scheduling.md
@@ -39,8 +39,8 @@ Resource allocation can be configured as follows, based on the cluster type:
* **Mesos:** To use static partitioning on Mesos, set the `spark.mesos.coarse` configuration property to `true`,
and optionally set `spark.cores.max` to limit each application's resource share as in the standalone mode.
You should also set `spark.executor.memory` to control the executor memory.
-* **YARN:** The `--num-workers` option to the Spark YARN client controls how many workers it will allocate
- on the cluster, while `--worker-memory` and `--worker-cores` control the resources per worker.
+* **YARN:** The `--num-executors` option to the Spark YARN client controls how many executors it will allocate
+ on the cluster, while `--executor-memory` and `--executor-cores` control the resources per executor.
A second option available on Mesos is _dynamic sharing_ of CPU cores. In this mode, each Spark application
still has a fixed and independent memory allocation (set by `spark.executor.memory`), but when the
diff --git a/docs/js/main.js b/docs/js/main.js
index 102699789a71a..0bd2286cced19 100755
--- a/docs/js/main.js
+++ b/docs/js/main.js
@@ -1,26 +1,3 @@
-
-// From docs.scala-lang.org
-function styleCode() {
- if (typeof disableStyleCode != "undefined") {
- return;
- }
- $(".codetabs pre code").parent().each(function() {
- if (!$(this).hasClass("prettyprint")) {
- var lang = $(this).parent().data("lang");
- if (lang == "python") {
- lang = "py"
- }
- if (lang == "bash") {
- lang = "bsh"
- }
- $(this).addClass("prettyprint lang-"+lang+" linenums");
- }
- });
- console.log("runningPrettyPrint()")
- prettyPrint();
-}
-
-
function codeTabs() {
var counter = 0;
var langImages = {
@@ -97,11 +74,7 @@ function viewSolution() {
}
-$(document).ready(function() {
+$(function() {
codeTabs();
viewSolution();
- $('#chapter-toc').toc({exclude: '', context: '.container'});
- $('#chapter-toc').prepend('
In This Chapter
');
- makeCollapsable($('#global-toc'), "", "global-toc", "Show Table of Contents");
- //styleCode();
});
diff --git a/docs/mllib-classification-regression.md b/docs/mllib-classification-regression.md
index 18a3e8e075086..d5bd8042ca2ec 100644
--- a/docs/mllib-classification-regression.md
+++ b/docs/mllib-classification-regression.md
@@ -77,8 +77,8 @@ between the two goals of small loss and small model complexity.
**Distributed Datasets.**
For all currently implemented optimization methods for classification, the data must be
-distributed between the worker machines *by examples*. Every machine holds a consecutive block of
-the `$n$` example/label pairs `$(\x_i,y_i)$`.
+distributed between processes on the worker machines *by examples*. Machines hold consecutive
+blocks of the `$n$` example/label pairs `$(\x_i,y_i)$`.
In other words, the input distributed dataset
([RDD](scala-programming-guide.html#resilient-distributed-datasets-rdds)) must be the set of
vectors `$\x_i\in\R^d$`.
diff --git a/docs/monitoring.md b/docs/monitoring.md
index e9b1d2b2f4ffb..15bfb041780da 100644
--- a/docs/monitoring.md
+++ b/docs/monitoring.md
@@ -48,11 +48,22 @@ Each instance can report to zero or more _sinks_. Sinks are contained in the
* `ConsoleSink`: Logs metrics information to the console.
* `CSVSink`: Exports metrics data to CSV files at regular intervals.
-* `GangliaSink`: Sends metrics to a Ganglia node or multicast group.
* `JmxSink`: Registers metrics for viewing in a JMX console.
* `MetricsServlet`: Adds a servlet within the existing Spark UI to serve metrics data as JSON data.
* `GraphiteSink`: Sends metrics to a Graphite node.
+Spark also supports a Ganglia sink which is not included in the default build due to
+licensing restrictions:
+
+* `GangliaSink`: Sends metrics to a Ganglia node or multicast group.
+
+To install the `GangliaSink` you'll need to perform a custom build of Spark. _**Note that
+by embedding this library you will include [LGPL](http://www.gnu.org/copyleft/lesser.html)-licensed
+code in your Spark package**_. For sbt users, set the
+`SPARK_GANGLIA_LGPL` environment variable before building. For Maven users, enable
+the `-Pspark-ganglia-lgpl` profile. In addition to modifying the cluster's Spark build,
+user applications will also need to link to the `spark-ganglia-lgpl` artifact.
+
The syntax of the metrics configuration file is defined in an example configuration file,
`$SPARK_HOME/conf/metrics.properties.template`.
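
The `spark-ganglia-lgpl` linking requirement mentioned above can be expressed in an application's sbt build roughly as follows; this sketch mirrors the `sbt_app_ganglia/build.sbt` added earlier in this patch, and the version strings are placeholders to be matched to the Spark build you deploy against:

```scala
// build.sbt sketch: linking the LGPL-licensed Ganglia sink from a user application.
// Versions are placeholders, not values mandated by this patch.
libraryDependencies += "org.apache.spark" %% "spark-core" % "1.0.0-SNAPSHOT"

libraryDependencies += "org.apache.spark" %% "spark-ganglia-lgpl" % "1.0.0-SNAPSHOT"
```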
diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md
index 7c5283fb0b6fb..cbe7d820b455e 100644
--- a/docs/python-programming-guide.md
+++ b/docs/python-programming-guide.md
@@ -43,9 +43,9 @@ def is_error(line):
errors = logData.filter(is_error)
{% endhighlight %}
-PySpark will automatically ship these functions to workers, along with any objects that they reference.
-Instances of classes will be serialized and shipped to workers by PySpark, but classes themselves cannot be automatically distributed to workers.
-The [Standalone Use](#standalone-use) section describes how to ship code dependencies to workers.
+PySpark will automatically ship these functions to executors, along with any objects that they reference.
+Instances of classes will be serialized and shipped to executors by PySpark, but classes themselves cannot be automatically distributed to executors.
+The [Standalone Use](#standalone-use) section describes how to ship code dependencies to executors.
In addition, PySpark fully supports interactive use---simply run `./bin/pyspark` to launch an interactive shell.
@@ -157,7 +157,7 @@ some example applications.
# Where to Go from Here
-PySpark also includes several sample programs in the [`python/examples` folder](https://github.com/apache/incubator-spark/tree/master/python/examples).
+PySpark also includes several sample programs in the [`python/examples` folder](https://github.com/apache/spark/tree/master/python/examples).
You can run them by passing the files to `pyspark`; e.g.:
./bin/pyspark python/examples/wordcount.py
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index cd4509ede735a..2e9dec4856ee9 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -29,7 +29,7 @@ If you want to test out the YARN deployment mode, you can use the current Spark
# Configuration
-Most of the configs are the same for Spark on YARN as other deploys. See the Configuration page for more information on those. These are configs that are specific to SPARK on YARN.
+Most of the configs are the same for Spark on YARN as for other deployment modes. See the Configuration page for more information on those. These are configs that are specific to Spark on YARN.
Environment variables:
@@ -41,28 +41,29 @@ System Properties:
* `spark.yarn.submit.file.replication`, the HDFS replication level for the files uploaded into HDFS for the application. These include things like the spark jar, the app jar, and any distributed cache files/archives.
* `spark.yarn.preserve.staging.files`, set to true to preserve the staged files (spark jar, app jar, distributed cache files) at the end of the job rather than deleting them.
* `spark.yarn.scheduler.heartbeat.interval-ms`, the interval in ms in which the Spark application master heartbeats into the YARN ResourceManager. Default is 5 seconds.
-* `spark.yarn.max.worker.failures`, the maximum number of worker failures before failing the application. Default is the number of workers requested times 2 with minimum of 3.
+* `spark.yarn.max.executor.failures`, the maximum number of executor failures before failing the application. Default is the number of executors requested times 2, with a minimum of 3.
# Launching Spark on YARN
-Ensure that HADOOP_CONF_DIR or YARN_CONF_DIR points to the directory which contains the (client side) configuration files for the hadoop cluster.
-This would be used to connect to the cluster, write to the dfs and submit jobs to the resource manager.
+Ensure that HADOOP_CONF_DIR or YARN_CONF_DIR points to the directory which contains the (client side) configuration files for the Hadoop cluster.
+These configs are used to connect to the cluster, write to the dfs, and connect to the YARN ResourceManager.
-There are two scheduler mode that can be used to launch spark application on YARN.
+There are two scheduler modes that can be used to launch Spark applications on YARN. In yarn-cluster mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In yarn-client mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN.
-## Launch spark application by YARN Client with yarn-standalone mode.
+Unlike in Spark standalone and Mesos mode, in which the master's address is specified in the "master" parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the master parameter is simply "yarn-client" or "yarn-cluster".
-The command to launch the YARN Client is as follows:
+## Launching a Spark application with yarn-cluster mode.
+
+The command to launch the Spark application on the cluster is as follows:
SPARK_JAR= ./bin/spark-class org.apache.spark.deploy.yarn.Client \
--jar \
--class \
--args \
- --num-workers \
- --master-class
- --master-memory \
- --worker-memory \
- --worker-cores \
+ --num-executors \
+ --driver-memory \
+ --executor-memory \
+ --executor-cores \
--name \
--queue \
--addJars \
@@ -82,58 +83,61 @@ For example:
./bin/spark-class org.apache.spark.deploy.yarn.Client \
--jar examples/target/scala-{{site.SCALA_BINARY_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \
--class org.apache.spark.examples.SparkPi \
- --args yarn-standalone \
- --num-workers 3 \
- --master-memory 4g \
- --worker-memory 2g \
- --worker-cores 1
-
- # Examine the output (replace $YARN_APP_ID in the following with the "application identifier" output by the previous command)
- # (Note: YARN_APP_LOGS_DIR is usually /tmp/logs or $HADOOP_HOME/logs/userlogs depending on the Hadoop version.)
- $ cat $YARN_APP_LOGS_DIR/$YARN_APP_ID/container*_000001/stdout
- Pi is roughly 3.13794
+ --args yarn-cluster \
+ --num-executors 3 \
+ --driver-memory 4g \
+ --executor-memory 2g \
+ --executor-cores 1
-The above starts a YARN Client programs which start the default Application Master. Then SparkPi will be run as a child thread of Application Master, YARN Client will periodically polls the Application Master for status updates and displays them in the console. The client will exit once your application has finished running.
+The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of the Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Viewing Logs" section below for how to see driver and executor logs.
-With this mode, your application is actually run on the remote machine where the Application Master is run upon. Thus application that involve local interaction will not work well, e.g. spark-shell.
+Because the application is run on a remote machine where the Application Master is running, applications that involve local interaction, such as spark-shell, will not work.
-## Launch spark application with yarn-client mode.
+## Launching a Spark application with yarn-client mode.
-With yarn-client mode, the application will be launched locally. Just like running application or spark-shell on Local / Mesos / Standalone mode. The launch method is also the similar with them, just make sure that when you need to specify a master url, use "yarn-client" instead. And you also need to export the env value for SPARK_JAR and SPARK_YARN_APP_JAR
+With yarn-client mode, the application will be launched locally, just like running an application or spark-shell in Local / Mesos / Standalone client mode. The launch method is also the same; just make sure to specify the master URL as "yarn-client". You also need to export the SPARK_JAR environment variable.
Configuration in yarn-client mode:
-In order to tune worker core/number/memory etc. You need to export environment variables or add them to the spark configuration file (./conf/spark_env.sh). The following are the list of options.
+In order to tune executor cores/number/memory etc., you need to export environment variables or add them to the Spark configuration file (./conf/spark-env.sh). The following options are available.
-* `SPARK_YARN_APP_JAR`, Path to your application's JAR file (required)
-* `SPARK_WORKER_INSTANCES`, Number of workers to start (Default: 2)
-* `SPARK_WORKER_CORES`, Number of cores for the workers (Default: 1).
-* `SPARK_WORKER_MEMORY`, Memory per Worker (e.g. 1000M, 2G) (Default: 1G)
-* `SPARK_MASTER_MEMORY`, Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)
+* `SPARK_EXECUTOR_INSTANCES`, Number of executors to start (Default: 2)
+* `SPARK_EXECUTOR_CORES`, Number of cores per executor (Default: 1).
+* `SPARK_EXECUTOR_MEMORY`, Memory per executor (e.g. 1000M, 2G) (Default: 1G)
+* `SPARK_DRIVER_MEMORY`, Memory for driver (e.g. 1000M, 2G) (Default: 512 MB)
* `SPARK_YARN_APP_NAME`, The name of your application (Default: Spark)
-* `SPARK_YARN_QUEUE`, The hadoop queue to use for allocation requests (Default: 'default')
+* `SPARK_YARN_QUEUE`, The YARN queue to use for allocation requests (Default: 'default')
* `SPARK_YARN_DIST_FILES`, Comma separated list of files to be distributed with the job.
* `SPARK_YARN_DIST_ARCHIVES`, Comma separated list of archives to be distributed with the job.
For example:
SPARK_JAR=./assembly/target/scala-{{site.SCALA_BINARY_VERSION}}/spark-assembly-{{site.SPARK_VERSION}}-hadoop2.0.5-alpha.jar \
- SPARK_YARN_APP_JAR=examples/target/scala-{{site.SCALA_BINARY_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \
./bin/run-example org.apache.spark.examples.SparkPi yarn-client
+or
SPARK_JAR=./assembly/target/scala-{{site.SCALA_BINARY_VERSION}}/spark-assembly-{{site.SPARK_VERSION}}-hadoop2.0.5-alpha.jar \
- SPARK_YARN_APP_JAR=examples/target/scala-{{site.SCALA_BINARY_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \
MASTER=yarn-client ./bin/spark-shell
+## Viewing logs
+
+In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the yarn.log-aggregation-enable config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command.
+
+    yarn logs -applicationId <app ID>
+
+will print out the contents of all log files from all containers from the given application.
+
+When log aggregation isn't turned on, logs are retained locally on each machine under YARN_APP_LOGS_DIR, which is usually configured to /tmp/logs or $HADOOP_HOME/logs/userlogs depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID.
+
# Building Spark for Hadoop/YARN 2.2.x
-See [Building Spark with Maven](building-with-maven.html) for instructions on how to build Spark using the Maven process.
+See [Building Spark with Maven](building-with-maven.html) for instructions on how to build Spark using Maven.
-# Important Notes
+# Important notes
- Before Hadoop 2.2, YARN does not support cores in container resource requests. Thus, when running against an earlier version, the number of cores given via command line arguments cannot be passed to YARN. Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured.
-- The local directories used for spark will be the local directories configured for YARN (Hadoop Yarn config yarn.nodemanager.local-dirs). If the user specifies spark.local.dir, it will be ignored.
-- The --files and --archives options support specifying file names with the # similar to Hadoop. For example you can specify: --files localtest.txt#appSees.txt and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name appSees.txt and your application should use the name as appSees.txt to reference it when running on YARN.
+- The local directories used by Spark executors will be the local directories configured for YARN (Hadoop YARN config yarn.nodemanager.local-dirs). If the user specifies spark.local.dir, it will be ignored.
+- The --files and --archives options support specifying file names with the # symbol, similar to Hadoop. For example, you can specify --files localtest.txt#appSees.txt; this will upload the file you have locally named localtest.txt into HDFS, but it will be linked to by the name appSees.txt, and your application should use the name appSees.txt to reference it when running on YARN.
- The --addJars option allows the SparkContext.addJar function to work if you are using it with local files. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files.
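+
+As a rough sketch of that last note (paths here are placeholders): a local jar only works with `SparkContext.addJar` if it was also passed to the launcher through `--addJars`, while remote URIs need no extra handling:
+
+{% highlight scala %}
+sc.addJar("./lib/dep.jar")                     // local file: also pass --addJars ./lib/dep.jar at launch
+sc.addJar("hdfs://namenode:8020/jars/dep.jar") // HDFS/HTTP/HTTPS/FTP URI: works without --addJars
+{% endhighlight %}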
diff --git a/docs/scala-programming-guide.md b/docs/scala-programming-guide.md
index 506d3faa767f3..99412733d4268 100644
--- a/docs/scala-programming-guide.md
+++ b/docs/scala-programming-guide.md
@@ -365,7 +365,7 @@ res2: Int = 10
# Where to Go from Here
-You can see some [example Spark programs](http://spark.incubator.apache.org/examples.html) on the Spark website.
+You can see some [example Spark programs](http://spark.apache.org/examples.html) on the Spark website.
In addition, Spark includes several samples in `examples/src/main/scala`. Some of them have both Spark versions and local (non-parallel) versions, allowing you to see what had to be changed to make the program run on a cluster. You can run them by passing the class name to the `bin/run-example` script included in Spark; for example:
./bin/run-example org.apache.spark.examples.SparkPi
diff --git a/docs/security.md b/docs/security.md
new file mode 100644
index 0000000000000..9e4218fbcfe7d
--- /dev/null
+++ b/docs/security.md
@@ -0,0 +1,18 @@
+---
+layout: global
+title: Spark Security
+---
+
+Spark currently supports authentication via a shared secret. Authentication can be configured to be on via the `spark.authenticate` configuration parameter. This parameter controls whether the Spark communication protocols do authentication using the shared secret. This authentication is a basic handshake to make sure both sides have the same shared secret and are allowed to communicate. If the shared secret is not identical they will not be allowed to communicate.
+
+The Spark UI can also be secured by using javax servlet filters. A user may want to secure the UI if it has data that other users should not be allowed to see. The javax servlet filter specified by the user can authenticate the user and then once the user is logged in, Spark can compare that user against the view acls to make sure they are authorized to view the UI. The configs 'spark.ui.acls.enable' and 'spark.ui.view.acls' control the behavior of the acls. Note that the person who started the application always has view access to the UI.
+
+For Spark on Yarn deployments, configuring `spark.authenticate` to true will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. The Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters. If an authentication filter is enabled, the acls controls can be used to control which users can view the Spark UI.
+
+For other types of Spark deployments, the spark config `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. The UI can be secured using a javax servlet filter installed via `spark.ui.filters`. If an authentication filter is enabled, the acls controls can be used to control which users can view the Spark UI.
+
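+As a minimal sketch (with an arbitrary secret value), this amounts to setting the two properties described above in each application's and daemon's configuration:
+
+{% highlight scala %}
+import org.apache.spark.SparkConf
+
+val conf = new SparkConf()
+  .set("spark.authenticate", "true")
+  // Must be identical on the Master, the Workers, and every application.
+  .set("spark.authenticate.secret", "good-secret")
+{% endhighlight %}
+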
+IMPORTANT NOTE: The NettyBlockFetcherIterator is not secured, so do not use Netty for the shuffle if running with authentication on.
+
+See [Spark Configuration](configuration.html) for more details on the security configs.
+
+See org.apache.spark.SecurityManager for implementation details about security.
diff --git a/docs/spark-debugger.md b/docs/spark-debugger.md
index 11c51d5cde7c9..891c2bfa8943d 100644
--- a/docs/spark-debugger.md
+++ b/docs/spark-debugger.md
@@ -2,7 +2,7 @@
layout: global
title: The Spark Debugger
---
-**Summary:** The Spark debugger provides replay debugging for deterministic (logic) errors in Spark programs. It's currently in development, but you can try it out in the [arthur branch](https://github.com/apache/incubator-spark/tree/arthur).
+**Summary:** The Spark debugger provides replay debugging for deterministic (logic) errors in Spark programs. It's currently in development, but you can try it out in the [arthur branch](https://github.com/apache/spark/tree/arthur).
## Introduction
@@ -19,7 +19,7 @@ For deterministic errors, debugging a Spark program is now as easy as debugging
## Approach
-As your Spark program runs, the slaves report key events back to the master -- for example, RDD creations, RDD contents, and uncaught exceptions. (A full list of event types is in [EventLogging.scala](https://github.com/apache/incubator-spark/blob/arthur/core/src/main/scala/spark/EventLogging.scala).) The master logs those events, and you can load the event log into the debugger after your program is done running.
+As your Spark program runs, the slaves report key events back to the master -- for example, RDD creations, RDD contents, and uncaught exceptions. (A full list of event types is in [EventLogging.scala](https://github.com/apache/spark/blob/arthur/core/src/main/scala/spark/EventLogging.scala).) The master logs those events, and you can load the event log into the debugger after your program is done running.
_A note on nondeterminism:_ For fault recovery, Spark requires RDD transformations (for example, the function passed to `RDD.map`) to be deterministic. The Spark debugger also relies on this property, and it can also warn you if your transformation is nondeterministic. This works by checksumming the contents of each RDD and comparing the checksums from the original execution to the checksums after recomputing the RDD in the debugger.
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 57e88581616a2..f9904d45013f6 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -58,11 +58,21 @@ do is as follows.
+First, we import the names of the Spark Streaming classes, and some implicit
+conversions from StreamingContext into our environment, to add useful methods to
+other classes we need (like DStream).
-First, we create a
-[StreamingContext](api/streaming/index.html#org.apache.spark.streaming.StreamingContext) object,
-which is the main entry point for all streaming
-functionality. Besides Spark's configuration, we specify that any DStream will be processed
+[StreamingContext](api/streaming/index.html#org.apache.spark.streaming.StreamingContext) is the
+main entry point for all streaming functionality.
+
+{% highlight scala %}
+import org.apache.spark.streaming._
+import org.apache.spark.streaming.StreamingContext._
+{% endhighlight %}
+
+Then we create a
+[StreamingContext](api/streaming/index.html#org.apache.spark.streaming.StreamingContext) object.
+Besides Spark's configuration, we specify that any DStream will be processed
in 1 second batches.
{% highlight scala %}
@@ -98,7 +108,7 @@ val pairs = words.map(word => (word, 1))
val wordCounts = pairs.reduceByKey(_ + _)
// Print a few of the counts to the console
-wordCount.print()
+wordCounts.print()
{% endhighlight %}
The `words` DStream is further mapped (one-to-one transformation) to a DStream of `(word,
@@ -178,7 +188,7 @@ JavaPairDStream wordCounts = pairs.reduceByKey(
return i1 + i2;
}
});
-wordCount.print(); // Print a few of the counts to the console
+wordCounts.print(); // Print a few of the counts to the console
{% endhighlight %}
The `words` DStream is further mapped (one-to-one transformation) to a DStream of `(word,
@@ -262,6 +272,24 @@ Time: 1357008430000 ms
+If you plan to run the Scala code for Spark Streaming-based use cases in the Spark
+shell, you should start the shell with the Spark configuration pre-configured to
+discard old batches periodically:
+
+{% highlight bash %}
+$ SPARK_JAVA_OPTS=-Dspark.cleaner.ttl=10000 bin/spark-shell
+{% endhighlight %}
+
+... and create your StreamingContext by wrapping the existing interactive shell
+SparkContext object, `sc`:
+
+{% highlight scala %}
+val ssc = new StreamingContext(sc, Seconds(1))
+{% endhighlight %}
+
+When working with the shell, you may also need to send a `^D` to your netcat session
+to force the pipeline to print the word counts to the console at the sink.
+
***************************************************************************************************
# Basics
@@ -511,7 +539,7 @@ common ones are as follows.
updateStateByKey(func)
Return a new "state" DStream where the state for each key is updated by applying the
given function on the previous state of the key and the new values for the key. This can be
- used to maintain arbitrary state data for each ket.
+ used to maintain arbitrary state data for each key.
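+
+For example, a running per-key count with `updateStateByKey` might look roughly like this (assuming a DStream of `(String, Int)` pairs named `pairs`, as in the earlier word count example, and a checkpoint directory set on the StreamingContext):
+
+{% highlight scala %}
+// newValues holds this batch's counts for a word; runningCount is the previous total.
+val runningCounts = pairs.updateStateByKey[Int] { (newValues: Seq[Int], runningCount: Option[Int]) =>
+  Some(newValues.sum + runningCount.getOrElse(0))
+}
+runningCounts.print()
+{% endhighlight %}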
diff --git a/docs/tuning.md b/docs/tuning.md
index 6b010aed618a3..093df3187a789 100644
--- a/docs/tuning.md
+++ b/docs/tuning.md
@@ -44,7 +44,10 @@ This setting configures the serializer used for not only shuffling data between
nodes but also when serializing RDDs to disk. The only reason Kryo is not the default is because of the custom
registration requirement, but we recommend trying it in any network-intensive application.
-Finally, to register your classes with Kryo, create a public class that extends
+Spark automatically includes Kryo serializers for the many commonly-used core Scala classes covered
+in the AllScalaRegistrar from the [Twitter chill](https://github.com/twitter/chill) library.
+
+To register your own custom classes with Kryo, create a public class that extends
[`org.apache.spark.serializer.KryoRegistrator`](api/core/index.html#org.apache.spark.serializer.KryoRegistrator) and set the
`spark.kryo.registrator` config property to point to it, as follows:
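+
+For illustration, such a registrator might look roughly like this (`MyClass1` and `MyClass2` are placeholders for your own classes):
+
+{% highlight scala %}
+import com.esotericsoftware.kryo.Kryo
+import org.apache.spark.serializer.KryoRegistrator
+
+class MyRegistrator extends KryoRegistrator {
+  override def registerClasses(kryo: Kryo) {
+    kryo.register(classOf[MyClass1])
+    kryo.register(classOf[MyClass2])
+  }
+}
+{% endhighlight %}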
@@ -72,8 +75,8 @@ If your objects are large, you may also need to increase the `spark.kryoserializ
config property. The default is 2, but this value needs to be large enough to hold the *largest*
object you will serialize.
-Finally, if you don't register your classes, Kryo will still work, but it will have to store the
-full class name with each object, which is wasteful.
+Finally, if you don't register your custom classes, Kryo will still work, but it will have to store
+the full class name with each object, which is wasteful.
# Memory Tuning
@@ -160,8 +163,8 @@ their work directories), *not* on your driver program.
**Cache Size Tuning**
One important configuration parameter for GC is the amount of memory that should be used for caching RDDs.
-By default, Spark uses 66% of the configured executor memory (`spark.executor.memory` or `SPARK_MEM`) to
-cache RDDs. This means that 33% of memory is available for any objects created during task execution.
+By default, Spark uses 60% of the configured executor memory (`spark.executor.memory`) to
+cache RDDs. This means that 40% of memory is available for any objects created during task execution.
In case your tasks slow down and you find that your JVM is garbage-collecting frequently or running out of
memory, lowering this value will help reduce the memory consumption. To change this to say 50%, you can call
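+something like the following (assuming the `spark.storage.memoryFraction` property and a `SparkConf` instance named `conf`):
+
+{% highlight scala %}
+conf.set("spark.storage.memoryFraction", "0.5")
+{% endhighlight %}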
diff --git a/ec2/README b/ec2/README
index 433da37b4c37c..72434f24bf98d 100644
--- a/ec2/README
+++ b/ec2/README
@@ -1,4 +1,4 @@
This folder contains a script, spark-ec2, for launching Spark clusters on
Amazon EC2. Usage instructions are available online at:
-http://spark.incubator.apache.org/docs/latest/ec2-scripts.html
+http://spark.apache.org/docs/latest/ec2-scripts.html
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index b0512ca891ad6..d8840c94ac17c 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -73,7 +73,7 @@ def parse_args():
parser.add_option("-v", "--spark-version", default="0.9.0",
help="Version of Spark to use: 'X.Y.Z' or a specific git hash")
parser.add_option("--spark-git-repo",
- default="https://github.com/apache/incubator-spark",
+ default="https://github.com/apache/spark",
help="Github repo from which to checkout supplied commit hash")
parser.add_option("--hadoop-major-version", default="1",
help="Major version of Hadoop (default: 1)")
@@ -398,15 +398,13 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
if any((master_nodes, slave_nodes)):
print ("Found %d master(s), %d slaves" %
(len(master_nodes), len(slave_nodes)))
- if (master_nodes != [] and slave_nodes != []) or not die_on_error:
+ if master_nodes != [] or not die_on_error:
return (master_nodes, slave_nodes)
else:
if master_nodes == [] and slave_nodes != []:
- print "ERROR: Could not find master in group " + cluster_name + "-master"
- elif master_nodes != [] and slave_nodes == []:
- print "ERROR: Could not find slaves in group " + cluster_name + "-slaves"
+ print >> sys.stderr, "ERROR: Could not find master in group " + cluster_name + "-master"
else:
- print "ERROR: Could not find any existing cluster"
+ print >> sys.stderr, "ERROR: Could not find any existing cluster"
sys.exit(1)
@@ -680,6 +678,9 @@ def real_main():
opts.zone = random.choice(conn.get_all_zones()).name
if action == "launch":
+ if opts.slaves <= 0:
+ print >> sys.stderr, "ERROR: You have to start at least 1 slave"
+ sys.exit(1)
if opts.resume:
(master_nodes, slave_nodes) = get_existing_cluster(
conn, opts, cluster_name)
diff --git a/examples/pom.xml b/examples/pom.xml
index 12a11821a4947..382a38d9400b9 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -21,7 +21,7 @@
  <groupId>org.apache.spark</groupId>
  <artifactId>spark-parent</artifactId>
- <version>1.0.0-incubating-SNAPSHOT</version>
+ <version>1.0.0-SNAPSHOT</version>
  <relativePath>../pom.xml</relativePath>
@@ -29,22 +29,21 @@
  <artifactId>spark-examples_2.10</artifactId>
  <packaging>jar</packaging>
  <name>Spark Project Examples</name>
- <url>http://spark.incubator.apache.org/</url>
-
- <repositories>
-   <repository>
-     <id>apache-repo</id>
-     <name>Apache Repository</name>
-     <url>https://repository.apache.org/content/repositories/releases</url>
-     <releases>
-       <enabled>true</enabled>
-     </releases>
-     <snapshots>
-       <enabled>false</enabled>
-     </snapshots>
-   </repository>
- </repositories>
+ <url>http://spark.apache.org/</url>
+
+ <profiles>
+   <profile>
+     <id>yarn-alpha</id>
+     <dependencies>
+       <dependency>
+         <groupId>org.apache.avro</groupId>
+         <artifactId>avro</artifactId>
+       </dependency>
+     </dependencies>
+   </profile>
+ </profiles>
@@ -169,14 +168,6 @@
  <groupId>org.apache.cassandra.deps</groupId>
  <artifactId>avro</artifactId>
- <exclusion>
-   <groupId>org.sonatype.sisu.inject</groupId>
-   <artifactId>*</artifactId>
- </exclusion>
- <exclusion>
-   <groupId>org.xerial.snappy</groupId>
-   <artifactId>*</artifactId>
- </exclusion>
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java
index d552c47b22231..6b49244ba459d 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java
@@ -45,7 +45,7 @@ static class DataPoint implements Serializable {
double y;
}
- static class ParsePoint extends Function<String, DataPoint> {
+ static class ParsePoint implements Function<String, DataPoint> {
private static final Pattern SPACE = Pattern.compile(" ");
@Override
@@ -60,7 +60,7 @@ public DataPoint call(String line) {
}
}
- static class VectorSum extends Function2<double[], double[], double[]> {
+ static class VectorSum implements Function2<double[], double[], double[]> {
@Override
public double[] call(double[] a, double[] b) {
double[] result = new double[D];
@@ -71,7 +71,7 @@ public double[] call(double[] a, double[] b) {
}
}
- static class ComputeGradient extends Function<DataPoint, double[]> {
+ static class ComputeGradient implements Function<DataPoint, double[]> {
private final double[] weights;
ComputeGradient(double[] weights) {
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaKMeans.java b/examples/src/main/java/org/apache/spark/examples/JavaKMeans.java
index 0dc879275a22a..2d797279d5bcc 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaKMeans.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaKMeans.java
@@ -98,7 +98,7 @@ public Vector call(String line) {
double tempDist;
do {
// allocate each vector to closest centroid
- JavaPairRDD<Integer, Vector> closest = data.map(
+ JavaPairRDD<Integer, Vector> closest = data.mapToPair(
new PairFunction<Vector, Integer, Vector>() {
@Override
public Tuple2<Integer, Vector> call(Vector vector) {
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java
index 9eb1cadd71d22..617e4a6d045e0 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java
@@ -108,7 +108,7 @@ public static void main(String[] args) {
JavaRDD<String> dataSet = (args.length == 2) ? jsc.textFile(args[1]) : jsc.parallelize(exampleApacheLogs);
- JavaPairRDD<Tuple3<String, String, String>, Stats> extracted = dataSet.map(new PairFunction<String, Tuple3<String, String, String>, Stats>() {
+ JavaPairRDD<Tuple3<String, String, String>, Stats> extracted = dataSet.mapToPair(new PairFunction<String, Tuple3<String, String, String>, Stats>() {
@Override
public Tuple2<Tuple3<String, String, String>, Stats> call(String s) {
return new Tuple2<Tuple3<String, String, String>, Stats>(extractKey(s), extractStats(s));
@@ -124,7 +124,7 @@ public Stats call(Stats stats, Stats stats2) {
List<Tuple2<Tuple3<String, String, String>, Stats>> output = counts.collect();
for (Tuple2,?> t : output) {
- System.out.println(t._1 + "\t" + t._2);
+ System.out.println(t._1() + "\t" + t._2());
}
System.exit(0);
}
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java
index a84245b0c7449..eb70fb547564c 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java
@@ -42,7 +42,7 @@
public final class JavaPageRank {
private static final Pattern SPACES = Pattern.compile("\\s+");
- private static class Sum extends Function2<Double, Double, Double> {
+ private static class Sum implements Function2<Double, Double, Double> {
@Override
public Double call(Double a, Double b) {
return a + b;
@@ -66,7 +66,7 @@ public static void main(String[] args) throws Exception {
JavaRDD<String> lines = ctx.textFile(args[1], 1);
// Loads all URLs from input file and initialize their neighbors.
- JavaPairRDD<String, List<String>> links = lines.map(new PairFunction<String, String, String>() {
+ JavaPairRDD<String, List<String>> links = lines.mapToPair(new PairFunction<String, String, String>() {
@Override
public Tuple2<String, String> call(String s) {
String[] parts = SPACES.split(s);
@@ -86,12 +86,12 @@ public Double call(List rs) {
for (int current = 0; current < Integer.parseInt(args[2]); current++) {
// Calculates URL contributions to the rank of other URLs.
JavaPairRDD<String, Double> contribs = links.join(ranks).values()
- .flatMap(new PairFlatMapFunction<Tuple2<List<String>, Double>, String, Double>() {
+ .flatMapToPair(new PairFlatMapFunction<Tuple2<List<String>, Double>, String, Double>() {
@Override
public Iterable<Tuple2<String, Double>> call(Tuple2<List<String>, Double> s) {
List<Tuple2<String, Double>> results = new ArrayList<Tuple2<String, Double>>();
- for (String n : s._1) {
- results.add(new Tuple2<String, Double>(n, s._2 / s._1.size()));
+ for (String n : s._1()) {
+ results.add(new Tuple2<String, Double>(n, s._2() / s._1().size()));
}
return results;
}
@@ -109,7 +109,7 @@ public Double call(Double sum) {
// Collects all URL ranks and dump them to console.
List<Tuple2<String, Double>> output = ranks.collect();
for (Tuple2,?> tuple : output) {
- System.out.println(tuple._1 + " has rank: " + tuple._2 + ".");
+ System.out.println(tuple._1() + " has rank: " + tuple._2() + ".");
}
System.exit(0);
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaTC.java b/examples/src/main/java/org/apache/spark/examples/JavaTC.java
index 2ceb0fd94ba65..6cfe25c80ecc6 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaTC.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaTC.java
@@ -50,7 +50,7 @@ static List> generateGraph() {
return new ArrayList