diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000..2b65f6fe3cc80
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,2 @@
+*.bat text eol=crlf
+*.cmd text eol=crlf
diff --git a/.rat-excludes b/.rat-excludes
index b14ad53720f32..d8bee1f8e49c9 100644
--- a/.rat-excludes
+++ b/.rat-excludes
@@ -1,5 +1,6 @@
target
.gitignore
+.gitattributes
.project
.classpath
.mima-excludes
@@ -43,11 +44,13 @@ SparkImports.scala
SparkJLineCompletion.scala
SparkJLineReader.scala
SparkMemberHandlers.scala
+SparkReplReporter.scala
sbt
sbt-launch-lib.bash
plugins.sbt
work
.*\.q
+.*\.qv
golden
test.out/*
.*iml
diff --git a/LICENSE b/LICENSE
index 0517dfb0ab53d..4f2f0e7a7006a 100644
--- a/LICENSE
+++ b/LICENSE
@@ -712,18 +712,6 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-========================================================================
-For colt:
-========================================================================
-
-Copyright (c) 1999 CERN - European Organization for Nuclear Research.
-Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose is hereby granted without fee, provided that the above copyright notice appear in all copies and that both that copyright notice and this permission notice appear in supporting documentation. CERN makes no representations about the suitability of this software for any purpose. It is provided "as is" without expressed or implied warranty.
-
-Packages hep.aida.*
-
-Written by Pavel Binko, Dino Ferrero Merlino, Wolfgang Hoschek, Tony Johnson, Andreas Pfeiffer, and others. Check the FreeHEP home page for more info. Permission to use and/or redistribute this work is granted under the terms of the LGPL License, with the exception that any usage related to military applications is expressly forbidden. The software and documentation made available under the terms of this license are provided with no warranty.
-
-
========================================================================
For SnapTree:
========================================================================
@@ -766,7 +754,7 @@ SUCH DAMAGE.
========================================================================
-For Timsort (core/src/main/java/org/apache/spark/util/collection/Sorter.java):
+For Timsort (core/src/main/java/org/apache/spark/util/collection/TimSort.java):
========================================================================
Copyright (C) 2008 The Android Open Source Project
@@ -783,6 +771,25 @@ See the License for the specific language governing permissions and
limitations under the License.
+========================================================================
+For LimitedInputStream
+ (network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java):
+========================================================================
+Copyright (C) 2007 The Guava Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+
========================================================================
BSD-style licenses
========================================================================
diff --git a/README.md b/README.md
index 8dd8b70696aa2..8d57d50da96c9 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,8 @@ and Spark Streaming for stream processing.
## Online Documentation
You can find the latest Spark documentation, including a programming
-guide, on the [project web page](http://spark.apache.org/documentation.html).
+guide, on the [project web page](http://spark.apache.org/documentation.html)
+and [project wiki](https://cwiki.apache.org/confluence/display/SPARK).
This README file only contains basic setup instructions.
## Building Spark
@@ -25,7 +26,7 @@ To build Spark and its example programs, run:
(You do not need to do this if you downloaded a pre-built package.)
More detailed documentation is available from the project site, at
-["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html).
+["Building Spark with Maven"](http://spark.apache.org/docs/latest/building-with-maven.html).
## Interactive Scala Shell
@@ -84,7 +85,7 @@ storage systems. Because the protocols have changed in different versions of
Hadoop, you must build Spark against the same version that your cluster runs.
Please refer to the build documentation at
-["Specifying the Hadoop Version"](http://spark.apache.org/docs/latest/building-spark.html#specifying-the-hadoop-version)
+["Specifying the Hadoop Version"](http://spark.apache.org/docs/latest/building-with-maven.html#specifying-the-hadoop-version)
for detailed guidance on building for a particular distribution of Hadoop, including
building for particular Hive and Hive Thriftserver distributions. See also
["Third Party Hadoop Distributions"](http://spark.apache.org/docs/latest/hadoop-third-party-distributions.html)
diff --git a/assembly/pom.xml b/assembly/pom.xml
index 31a01e4d8e1de..4e2b773e7d2f3 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -21,7 +21,7 @@
org.apache.sparkspark-parent
- 1.2.0-SNAPSHOT
+ 1.3.0-SNAPSHOT../pom.xml
@@ -66,22 +66,22 @@
org.apache.spark
- spark-repl_${scala.binary.version}
+ spark-streaming_${scala.binary.version}${project.version}org.apache.spark
- spark-streaming_${scala.binary.version}
+ spark-graphx_${scala.binary.version}${project.version}org.apache.spark
- spark-graphx_${scala.binary.version}
+ spark-sql_${scala.binary.version}${project.version}org.apache.spark
- spark-sql_${scala.binary.version}
+ spark-repl_${scala.binary.version}${project.version}
@@ -197,6 +197,11 @@
spark-hive_${scala.binary.version}${project.version}
+
+
+
+ hive-thriftserver
+ org.apache.sparkspark-hive-thriftserver_${scala.binary.version}
diff --git a/bagel/pom.xml b/bagel/pom.xml
index 93db0d5efda5f..0327ffa402671 100644
--- a/bagel/pom.xml
+++ b/bagel/pom.xml
@@ -21,7 +21,7 @@
org.apache.sparkspark-parent
- 1.2.0-SNAPSHOT
+ 1.3.0-SNAPSHOT../pom.xml
diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd
index 9b9e40321ea93..a4c099fb45b14 100644
--- a/bin/compute-classpath.cmd
+++ b/bin/compute-classpath.cmd
@@ -1,117 +1,117 @@
-@echo off
-
-rem
-rem Licensed to the Apache Software Foundation (ASF) under one or more
-rem contributor license agreements. See the NOTICE file distributed with
-rem this work for additional information regarding copyright ownership.
-rem The ASF licenses this file to You under the Apache License, Version 2.0
-rem (the "License"); you may not use this file except in compliance with
-rem the License. You may obtain a copy of the License at
-rem
-rem http://www.apache.org/licenses/LICENSE-2.0
-rem
-rem Unless required by applicable law or agreed to in writing, software
-rem distributed under the License is distributed on an "AS IS" BASIS,
-rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-rem See the License for the specific language governing permissions and
-rem limitations under the License.
-rem
-
-rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run"
-rem script and the ExecutorRunner in standalone cluster mode.
-
-rem If we're called from spark-class2.cmd, it already set enabledelayedexpansion and setting
-rem it here would stop us from affecting its copy of the CLASSPATH variable; otherwise we
-rem need to set it here because we use !datanucleus_jars! below.
-if "%DONT_PRINT_CLASSPATH%"=="1" goto skip_delayed_expansion
-setlocal enabledelayedexpansion
-:skip_delayed_expansion
-
-set SCALA_VERSION=2.10
-
-rem Figure out where the Spark framework is installed
-set FWDIR=%~dp0..\
-
-rem Load environment variables from conf\spark-env.cmd, if it exists
-if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd"
-
-rem Build up classpath
-set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH%
-
-if "x%SPARK_CONF_DIR%"!="x" (
- set CLASSPATH=%CLASSPATH%;%SPARK_CONF_DIR%
-) else (
- set CLASSPATH=%CLASSPATH%;%FWDIR%conf
-)
-
-if exist "%FWDIR%RELEASE" (
- for %%d in ("%FWDIR%lib\spark-assembly*.jar") do (
- set ASSEMBLY_JAR=%%d
- )
-) else (
- for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*.jar") do (
- set ASSEMBLY_JAR=%%d
- )
-)
-
-set CLASSPATH=%CLASSPATH%;%ASSEMBLY_JAR%
-
-rem When Hive support is needed, Datanucleus jars must be included on the classpath.
-rem Datanucleus jars do not work if only included in the uber jar as plugin.xml metadata is lost.
-rem Both sbt and maven will populate "lib_managed/jars/" with the datanucleus jars when Spark is
-rem built with Hive, so look for them there.
-if exist "%FWDIR%RELEASE" (
- set datanucleus_dir=%FWDIR%lib
-) else (
- set datanucleus_dir=%FWDIR%lib_managed\jars
-)
-set "datanucleus_jars="
-for %%d in ("%datanucleus_dir%\datanucleus-*.jar") do (
- set datanucleus_jars=!datanucleus_jars!;%%d
-)
-set CLASSPATH=%CLASSPATH%;%datanucleus_jars%
-
-set SPARK_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%tools\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\classes
-
-set SPARK_TEST_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\test-classes
-
-if "x%SPARK_TESTING%"=="x1" (
- rem Add test clases to path - note, add SPARK_CLASSES and SPARK_TEST_CLASSES before CLASSPATH
- rem so that local compilation takes precedence over assembled jar
- set CLASSPATH=%SPARK_CLASSES%;%SPARK_TEST_CLASSES%;%CLASSPATH%
-)
-
-rem Add hadoop conf dir - else FileSystem.*, etc fail
-rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts
-rem the configurtion files.
-if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir
- set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR%
-:no_hadoop_conf_dir
-
-if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir
- set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR%
-:no_yarn_conf_dir
-
-rem A bit of a hack to allow calling this script within run2.cmd without seeing output
-if "%DONT_PRINT_CLASSPATH%"=="1" goto exit
-
-echo %CLASSPATH%
-
-:exit
+@echo off
+
+rem
+rem Licensed to the Apache Software Foundation (ASF) under one or more
+rem contributor license agreements. See the NOTICE file distributed with
+rem this work for additional information regarding copyright ownership.
+rem The ASF licenses this file to You under the Apache License, Version 2.0
+rem (the "License"); you may not use this file except in compliance with
+rem the License. You may obtain a copy of the License at
+rem
+rem http://www.apache.org/licenses/LICENSE-2.0
+rem
+rem Unless required by applicable law or agreed to in writing, software
+rem distributed under the License is distributed on an "AS IS" BASIS,
+rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+rem See the License for the specific language governing permissions and
+rem limitations under the License.
+rem
+
+rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run"
+rem script and the ExecutorRunner in standalone cluster mode.
+
+rem If we're called from spark-class2.cmd, it already set enabledelayedexpansion and setting
+rem it here would stop us from affecting its copy of the CLASSPATH variable; otherwise we
+rem need to set it here because we use !datanucleus_jars! below.
+if "%DONT_PRINT_CLASSPATH%"=="1" goto skip_delayed_expansion
+setlocal enabledelayedexpansion
+:skip_delayed_expansion
+
+set SCALA_VERSION=2.10
+
+rem Figure out where the Spark framework is installed
+set FWDIR=%~dp0..\
+
+rem Load environment variables from conf\spark-env.cmd, if it exists
+if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd"
+
+rem Build up classpath
+set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH%
+
+if not "x%SPARK_CONF_DIR%"=="x" (
+ set CLASSPATH=%CLASSPATH%;%SPARK_CONF_DIR%
+) else (
+ set CLASSPATH=%CLASSPATH%;%FWDIR%conf
+)
+
+if exist "%FWDIR%RELEASE" (
+ for %%d in ("%FWDIR%lib\spark-assembly*.jar") do (
+ set ASSEMBLY_JAR=%%d
+ )
+) else (
+ for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*.jar") do (
+ set ASSEMBLY_JAR=%%d
+ )
+)
+
+set CLASSPATH=%CLASSPATH%;%ASSEMBLY_JAR%
+
+rem When Hive support is needed, Datanucleus jars must be included on the classpath.
+rem Datanucleus jars do not work if only included in the uber jar as plugin.xml metadata is lost.
+rem Both sbt and maven will populate "lib_managed/jars/" with the datanucleus jars when Spark is
+rem built with Hive, so look for them there.
+if exist "%FWDIR%RELEASE" (
+ set datanucleus_dir=%FWDIR%lib
+) else (
+ set datanucleus_dir=%FWDIR%lib_managed\jars
+)
+set "datanucleus_jars="
+for %%d in ("%datanucleus_dir%\datanucleus-*.jar") do (
+ set datanucleus_jars=!datanucleus_jars!;%%d
+)
+set CLASSPATH=%CLASSPATH%;%datanucleus_jars%
+
+set SPARK_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%tools\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\classes
+
+set SPARK_TEST_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\test-classes
+
+if "x%SPARK_TESTING%"=="x1" (
+ rem Add test clases to path - note, add SPARK_CLASSES and SPARK_TEST_CLASSES before CLASSPATH
+ rem so that local compilation takes precedence over assembled jar
+ set CLASSPATH=%SPARK_CLASSES%;%SPARK_TEST_CLASSES%;%CLASSPATH%
+)
+
+rem Add hadoop conf dir - else FileSystem.*, etc fail
+rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts
+rem the configurtion files.
+if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir
+ set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR%
+:no_hadoop_conf_dir
+
+if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir
+ set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR%
+:no_yarn_conf_dir
+
+rem A bit of a hack to allow calling this script within run2.cmd without seeing output
+if "%DONT_PRINT_CLASSPATH%"=="1" goto exit
+
+echo %CLASSPATH%
+
+:exit
diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh
index 905bbaf99b374..298641f2684de 100755
--- a/bin/compute-classpath.sh
+++ b/bin/compute-classpath.sh
@@ -20,8 +20,6 @@
# This script computes Spark's classpath and prints it to stdout; it's used by both the "run"
# script and the ExecutorRunner in standalone cluster mode.
-SCALA_VERSION=2.10
-
# Figure out where Spark is installed
FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
@@ -36,7 +34,7 @@ else
CLASSPATH="$CLASSPATH:$FWDIR/conf"
fi
-ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SCALA_VERSION"
+ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SPARK_SCALA_VERSION"
if [ -n "$JAVA_HOME" ]; then
JAR_CMD="$JAVA_HOME/bin/jar"
@@ -48,19 +46,19 @@ fi
if [ -n "$SPARK_PREPEND_CLASSES" ]; then
echo "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark"\
"classes ahead of assembly." >&2
- CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SPARK_SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/core/target/jars/*"
- CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/tools/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/tools/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SPARK_SCALA_VERSION/classes"
fi
# Use spark-assembly jar from either RELEASE or assembly directory
@@ -123,15 +121,15 @@ fi
# Add test classes if we're running from SBT or Maven with SPARK_TESTING set to 1
if [[ $SPARK_TESTING == 1 ]]; then
- CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SPARK_SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SPARK_SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SPARK_SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SPARK_SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SPARK_SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SPARK_SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SPARK_SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SPARK_SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/test-classes"
fi
# Add hadoop conf dir if given -- otherwise FileSystem.*, etc fail !
diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh
index 6d4231b204595..356b3d49b2ffe 100644
--- a/bin/load-spark-env.sh
+++ b/bin/load-spark-env.sh
@@ -36,3 +36,23 @@ if [ -z "$SPARK_ENV_LOADED" ]; then
set +a
fi
fi
+
+# Setting SPARK_SCALA_VERSION if not already set.
+
+if [ -z "$SPARK_SCALA_VERSION" ]; then
+
+ ASSEMBLY_DIR2="$FWDIR/assembly/target/scala-2.11"
+ ASSEMBLY_DIR1="$FWDIR/assembly/target/scala-2.10"
+
+ if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then
+ echo -e "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." 1>&2
+ echo -e 'Either clean one of them or, export SPARK_SCALA_VERSION=2.11 in spark-env.sh.' 1>&2
+ exit 1
+ fi
+
+ if [ -d "$ASSEMBLY_DIR2" ]; then
+ export SPARK_SCALA_VERSION="2.11"
+ else
+ export SPARK_SCALA_VERSION="2.10"
+ fi
+fi
diff --git a/bin/pyspark b/bin/pyspark
index 6655725ef8e8e..0b4f695dd06dd 100755
--- a/bin/pyspark
+++ b/bin/pyspark
@@ -25,7 +25,7 @@ export SPARK_HOME="$FWDIR"
source "$FWDIR/bin/utils.sh"
-SCALA_VERSION=2.10
+source "$FWDIR"/bin/load-spark-env.sh
function usage() {
echo "Usage: ./bin/pyspark [options]" 1>&2
@@ -40,7 +40,7 @@ fi
# Exit if the user hasn't compiled Spark
if [ ! -f "$FWDIR/RELEASE" ]; then
# Exit if the user hasn't compiled Spark
- ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar >& /dev/null
+ ls "$FWDIR"/assembly/target/scala-$SPARK_SCALA_VERSION/spark-assembly*hadoop*.jar >& /dev/null
if [[ $? != 0 ]]; then
echo "Failed to find Spark assembly in $FWDIR/assembly/target" 1>&2
echo "You need to build Spark before running this program" 1>&2
@@ -48,24 +48,47 @@ if [ ! -f "$FWDIR/RELEASE" ]; then
fi
fi
-. "$FWDIR"/bin/load-spark-env.sh
+# In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython`
+# executable, while the worker would still be launched using PYSPARK_PYTHON.
+#
+# In Spark 1.2, we removed the documentation of the IPYTHON and IPYTHON_OPTS variables and added
+# PYSPARK_DRIVER_PYTHON and PYSPARK_DRIVER_PYTHON_OPTS to allow IPython to be used for the driver.
+# Now, users can simply set PYSPARK_DRIVER_PYTHON=ipython to use IPython and set
+# PYSPARK_DRIVER_PYTHON_OPTS to pass options when starting the Python driver
+# (e.g. PYSPARK_DRIVER_PYTHON_OPTS='notebook'). This supports full customization of the IPython
+# and executor Python executables.
+#
+# For backwards-compatibility, we retain the old IPYTHON and IPYTHON_OPTS variables.
+
+# Determine the Python executable to use if PYSPARK_PYTHON or PYSPARK_DRIVER_PYTHON isn't set:
+if hash python2.7 2>/dev/null; then
+ # Attempt to use Python 2.7, if installed:
+ DEFAULT_PYTHON="python2.7"
+else
+ DEFAULT_PYTHON="python"
+fi
-# Figure out which Python executable to use
+# Determine the Python executable to use for the driver:
+if [[ -n "$IPYTHON_OPTS" || "$IPYTHON" == "1" ]]; then
+ # If IPython options are specified, assume user wants to run IPython
+ # (for backwards-compatibility)
+ PYSPARK_DRIVER_PYTHON_OPTS="$PYSPARK_DRIVER_PYTHON_OPTS $IPYTHON_OPTS"
+ PYSPARK_DRIVER_PYTHON="ipython"
+elif [[ -z "$PYSPARK_DRIVER_PYTHON" ]]; then
+ PYSPARK_DRIVER_PYTHON="${PYSPARK_PYTHON:-"$DEFAULT_PYTHON"}"
+fi
+
+# Determine the Python executable to use for the executors:
if [[ -z "$PYSPARK_PYTHON" ]]; then
- if [[ "$IPYTHON" = "1" || -n "$IPYTHON_OPTS" ]]; then
- # for backward compatibility
- PYSPARK_PYTHON="ipython"
+ if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && $DEFAULT_PYTHON != "python2.7" ]]; then
+ echo "IPython requires Python 2.7+; please install python2.7 or set PYSPARK_PYTHON" 1>&2
+ exit 1
else
- PYSPARK_PYTHON="python"
+ PYSPARK_PYTHON="$DEFAULT_PYTHON"
fi
fi
export PYSPARK_PYTHON
-if [[ -z "$PYSPARK_PYTHON_OPTS" && -n "$IPYTHON_OPTS" ]]; then
- # for backward compatibility
- PYSPARK_PYTHON_OPTS="$IPYTHON_OPTS"
-fi
-
# Add the PySpark classes to the Python path:
export PYTHONPATH="$SPARK_HOME/python/:$PYTHONPATH"
export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH"
@@ -93,9 +116,9 @@ if [[ -n "$SPARK_TESTING" ]]; then
unset YARN_CONF_DIR
unset HADOOP_CONF_DIR
if [[ -n "$PYSPARK_DOC_TEST" ]]; then
- exec "$PYSPARK_PYTHON" -m doctest $1
+ exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1
else
- exec "$PYSPARK_PYTHON" $1
+ exec "$PYSPARK_DRIVER_PYTHON" $1
fi
exit
fi
@@ -109,7 +132,5 @@ if [[ "$1" =~ \.py$ ]]; then
gatherSparkSubmitOpts "$@"
exec "$FWDIR"/bin/spark-submit "${SUBMISSION_OPTS[@]}" "$primary" "${APPLICATION_OPTS[@]}"
else
- # PySpark shell requires special handling downstream
- export PYSPARK_SHELL=1
- exec "$PYSPARK_PYTHON" $PYSPARK_PYTHON_OPTS
+ exec "$PYSPARK_DRIVER_PYTHON" $PYSPARK_DRIVER_PYTHON_OPTS
fi
diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd
index a0e66abcc26c9..a542ec80b49d6 100644
--- a/bin/pyspark2.cmd
+++ b/bin/pyspark2.cmd
@@ -59,7 +59,11 @@ for /f %%i in ('echo %1^| findstr /R "\.py"') do (
)
if [%PYTHON_FILE%] == [] (
- %PYSPARK_PYTHON%
+ if [%IPYTHON%] == [1] (
+ ipython %IPYTHON_OPTS%
+ ) else (
+ %PYSPARK_PYTHON%
+ )
) else (
echo.
echo WARNING: Running python applications through ./bin/pyspark.cmd is deprecated as of Spark 1.0.
diff --git a/bin/run-example b/bin/run-example
index 34dd71c71880e..3d932509426fc 100755
--- a/bin/run-example
+++ b/bin/run-example
@@ -17,12 +17,12 @@
# limitations under the License.
#
-SCALA_VERSION=2.10
-
FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
export SPARK_HOME="$FWDIR"
EXAMPLES_DIR="$FWDIR"/examples
+. "$FWDIR"/bin/load-spark-env.sh
+
if [ -n "$1" ]; then
EXAMPLE_CLASS="$1"
shift
@@ -36,8 +36,8 @@ fi
if [ -f "$FWDIR/RELEASE" ]; then
export SPARK_EXAMPLES_JAR="`ls "$FWDIR"/lib/spark-examples-*hadoop*.jar`"
-elif [ -e "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/spark-examples-*hadoop*.jar ]; then
- export SPARK_EXAMPLES_JAR="`ls "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/spark-examples-*hadoop*.jar`"
+elif [ -e "$EXAMPLES_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-examples-*hadoop*.jar ]; then
+ export SPARK_EXAMPLES_JAR="`ls "$EXAMPLES_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-examples-*hadoop*.jar`"
fi
if [[ -z "$SPARK_EXAMPLES_JAR" ]]; then
diff --git a/bin/spark-class b/bin/spark-class
index e8201c18d52de..0d58d95c1aee3 100755
--- a/bin/spark-class
+++ b/bin/spark-class
@@ -24,8 +24,6 @@ case "`uname`" in
CYGWIN*) cygwin=true;;
esac
-SCALA_VERSION=2.10
-
# Figure out where Spark is installed
FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
@@ -81,7 +79,11 @@ case "$1" in
OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_SUBMIT_OPTS"
OUR_JAVA_MEM=${SPARK_DRIVER_MEMORY:-$DEFAULT_MEM}
if [ -n "$SPARK_SUBMIT_LIBRARY_PATH" ]; then
- OUR_JAVA_OPTS="$OUR_JAVA_OPTS -Djava.library.path=$SPARK_SUBMIT_LIBRARY_PATH"
+ if [[ $OSTYPE == darwin* ]]; then
+ export DYLD_LIBRARY_PATH="$SPARK_SUBMIT_LIBRARY_PATH:$DYLD_LIBRARY_PATH"
+ else
+ export LD_LIBRARY_PATH="$SPARK_SUBMIT_LIBRARY_PATH:$LD_LIBRARY_PATH"
+ fi
fi
if [ -n "$SPARK_SUBMIT_DRIVER_MEMORY" ]; then
OUR_JAVA_MEM="$SPARK_SUBMIT_DRIVER_MEMORY"
@@ -105,7 +107,7 @@ else
exit 1
fi
fi
-JAVA_VERSION=$("$RUNNER" -version 2>&1 | sed 's/.* version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q')
+JAVA_VERSION=$("$RUNNER" -version 2>&1 | grep 'version' | sed 's/.* version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q')
# Set JAVA_OPTS to be able to load native libraries and to set heap size
if [ "$JAVA_VERSION" -ge 18 ]; then
@@ -124,9 +126,9 @@ fi
TOOLS_DIR="$FWDIR"/tools
SPARK_TOOLS_JAR=""
-if [ -e "$TOOLS_DIR"/target/scala-$SCALA_VERSION/spark-tools*[0-9Tg].jar ]; then
+if [ -e "$TOOLS_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-tools*[0-9Tg].jar ]; then
# Use the JAR from the SBT build
- export SPARK_TOOLS_JAR="`ls "$TOOLS_DIR"/target/scala-$SCALA_VERSION/spark-tools*[0-9Tg].jar`"
+ export SPARK_TOOLS_JAR="`ls "$TOOLS_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-tools*[0-9Tg].jar`"
fi
if [ -e "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar ]; then
# Use the JAR from the Maven build
@@ -145,7 +147,7 @@ fi
if [[ "$1" =~ org.apache.spark.tools.* ]]; then
if test -z "$SPARK_TOOLS_JAR"; then
- echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SCALA_VERSION/" 1>&2
+ echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SPARK_SCALA_VERSION/" 1>&2
echo "You need to build Spark before running $1." 1>&2
exit 1
fi
diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd
index 2ee60b4e2a2b3..8f90ba5a0b3b8 100755
--- a/bin/spark-shell.cmd
+++ b/bin/spark-shell.cmd
@@ -17,6 +17,7 @@ rem See the License for the specific language governing permissions and
rem limitations under the License.
rem
-set SPARK_HOME=%~dp0..
+rem This is the entry point for running Spark shell. To avoid polluting the
+rem environment, it just launches a new cmd to do the real work.
-cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd --class org.apache.spark.repl.Main %* spark-shell
+cmd /V /E /C %~dp0spark-shell2.cmd %*
diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd
new file mode 100644
index 0000000000000..2ee60b4e2a2b3
--- /dev/null
+++ b/bin/spark-shell2.cmd
@@ -0,0 +1,22 @@
+@echo off
+
+rem
+rem Licensed to the Apache Software Foundation (ASF) under one or more
+rem contributor license agreements. See the NOTICE file distributed with
+rem this work for additional information regarding copyright ownership.
+rem The ASF licenses this file to You under the Apache License, Version 2.0
+rem (the "License"); you may not use this file except in compliance with
+rem the License. You may obtain a copy of the License at
+rem
+rem http://www.apache.org/licenses/LICENSE-2.0
+rem
+rem Unless required by applicable law or agreed to in writing, software
+rem distributed under the License is distributed on an "AS IS" BASIS,
+rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+rem See the License for the specific language governing permissions and
+rem limitations under the License.
+rem
+
+set SPARK_HOME=%~dp0..
+
+cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd --class org.apache.spark.repl.Main %* spark-shell
diff --git a/bin/spark-submit b/bin/spark-submit
index c557311b4b20e..f92d90c3a66b0 100755
--- a/bin/spark-submit
+++ b/bin/spark-submit
@@ -22,6 +22,9 @@
export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)"
ORIG_ARGS=("$@")
+# Set COLUMNS for progress bar
+export COLUMNS=`tput cols`
+
while (($#)); do
if [ "$1" = "--deploy-mode" ]; then
SPARK_SUBMIT_DEPLOY_MODE=$2
diff --git a/bin/spark-submit.cmd b/bin/spark-submit.cmd
index cf6046d1547ad..8f3b84c7b971d 100644
--- a/bin/spark-submit.cmd
+++ b/bin/spark-submit.cmd
@@ -17,52 +17,7 @@ rem See the License for the specific language governing permissions and
rem limitations under the License.
rem
-rem NOTE: Any changes in this file must be reflected in SparkSubmitDriverBootstrapper.scala!
+rem This is the entry point for running Spark submit. To avoid polluting the
+rem environment, it just launches a new cmd to do the real work.
-set SPARK_HOME=%~dp0..
-set ORIG_ARGS=%*
-
-rem Reset the values of all variables used
-set SPARK_SUBMIT_DEPLOY_MODE=client
-set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_HOME%\conf\spark-defaults.conf
-set SPARK_SUBMIT_DRIVER_MEMORY=
-set SPARK_SUBMIT_LIBRARY_PATH=
-set SPARK_SUBMIT_CLASSPATH=
-set SPARK_SUBMIT_OPTS=
-set SPARK_SUBMIT_BOOTSTRAP_DRIVER=
-
-:loop
-if [%1] == [] goto continue
- if [%1] == [--deploy-mode] (
- set SPARK_SUBMIT_DEPLOY_MODE=%2
- ) else if [%1] == [--properties-file] (
- set SPARK_SUBMIT_PROPERTIES_FILE=%2
- ) else if [%1] == [--driver-memory] (
- set SPARK_SUBMIT_DRIVER_MEMORY=%2
- ) else if [%1] == [--driver-library-path] (
- set SPARK_SUBMIT_LIBRARY_PATH=%2
- ) else if [%1] == [--driver-class-path] (
- set SPARK_SUBMIT_CLASSPATH=%2
- ) else if [%1] == [--driver-java-options] (
- set SPARK_SUBMIT_OPTS=%2
- )
- shift
-goto loop
-:continue
-
-rem For client mode, the driver will be launched in the same JVM that launches
-rem SparkSubmit, so we may need to read the properties file for any extra class
-rem paths, library paths, java options and memory early on. Otherwise, it will
-rem be too late by the time the driver JVM has started.
-
-if [%SPARK_SUBMIT_DEPLOY_MODE%] == [client] (
- if exist %SPARK_SUBMIT_PROPERTIES_FILE% (
- rem Parse the properties file only if the special configs exist
- for /f %%i in ('findstr /r /c:"^[\t ]*spark.driver.memory" /c:"^[\t ]*spark.driver.extra" ^
- %SPARK_SUBMIT_PROPERTIES_FILE%') do (
- set SPARK_SUBMIT_BOOTSTRAP_DRIVER=1
- )
- )
-)
-
-cmd /V /E /C %SPARK_HOME%\bin\spark-class.cmd org.apache.spark.deploy.SparkSubmit %ORIG_ARGS%
+cmd /V /E /C %~dp0spark-submit2.cmd %*
diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd
new file mode 100644
index 0000000000000..cf6046d1547ad
--- /dev/null
+++ b/bin/spark-submit2.cmd
@@ -0,0 +1,68 @@
+@echo off
+
+rem
+rem Licensed to the Apache Software Foundation (ASF) under one or more
+rem contributor license agreements. See the NOTICE file distributed with
+rem this work for additional information regarding copyright ownership.
+rem The ASF licenses this file to You under the Apache License, Version 2.0
+rem (the "License"); you may not use this file except in compliance with
+rem the License. You may obtain a copy of the License at
+rem
+rem http://www.apache.org/licenses/LICENSE-2.0
+rem
+rem Unless required by applicable law or agreed to in writing, software
+rem distributed under the License is distributed on an "AS IS" BASIS,
+rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+rem See the License for the specific language governing permissions and
+rem limitations under the License.
+rem
+
+rem NOTE: Any changes in this file must be reflected in SparkSubmitDriverBootstrapper.scala!
+
+set SPARK_HOME=%~dp0..
+set ORIG_ARGS=%*
+
+rem Reset the values of all variables used
+set SPARK_SUBMIT_DEPLOY_MODE=client
+set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_HOME%\conf\spark-defaults.conf
+set SPARK_SUBMIT_DRIVER_MEMORY=
+set SPARK_SUBMIT_LIBRARY_PATH=
+set SPARK_SUBMIT_CLASSPATH=
+set SPARK_SUBMIT_OPTS=
+set SPARK_SUBMIT_BOOTSTRAP_DRIVER=
+
+:loop
+if [%1] == [] goto continue
+ if [%1] == [--deploy-mode] (
+ set SPARK_SUBMIT_DEPLOY_MODE=%2
+ ) else if [%1] == [--properties-file] (
+ set SPARK_SUBMIT_PROPERTIES_FILE=%2
+ ) else if [%1] == [--driver-memory] (
+ set SPARK_SUBMIT_DRIVER_MEMORY=%2
+ ) else if [%1] == [--driver-library-path] (
+ set SPARK_SUBMIT_LIBRARY_PATH=%2
+ ) else if [%1] == [--driver-class-path] (
+ set SPARK_SUBMIT_CLASSPATH=%2
+ ) else if [%1] == [--driver-java-options] (
+ set SPARK_SUBMIT_OPTS=%2
+ )
+ shift
+goto loop
+:continue
+
+rem For client mode, the driver will be launched in the same JVM that launches
+rem SparkSubmit, so we may need to read the properties file for any extra class
+rem paths, library paths, java options and memory early on. Otherwise, it will
+rem be too late by the time the driver JVM has started.
+
+if [%SPARK_SUBMIT_DEPLOY_MODE%] == [client] (
+ if exist %SPARK_SUBMIT_PROPERTIES_FILE% (
+ rem Parse the properties file only if the special configs exist
+ for /f %%i in ('findstr /r /c:"^[\t ]*spark.driver.memory" /c:"^[\t ]*spark.driver.extra" ^
+ %SPARK_SUBMIT_PROPERTIES_FILE%') do (
+ set SPARK_SUBMIT_BOOTSTRAP_DRIVER=1
+ )
+ )
+)
+
+cmd /V /E /C %SPARK_HOME%\bin\spark-class.cmd org.apache.spark.deploy.SparkSubmit %ORIG_ARGS%
diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template
index f8ffbf64278fb..0886b0276fb90 100755
--- a/conf/spark-env.sh.template
+++ b/conf/spark-env.sh.template
@@ -28,7 +28,7 @@
# - SPARK_YARN_DIST_FILES, Comma separated list of files to be distributed with the job.
# - SPARK_YARN_DIST_ARCHIVES, Comma separated list of archives to be distributed with the job.
-# Options for the daemons used in the standalone deploy mode:
+# Options for the daemons used in the standalone deploy mode
# - SPARK_MASTER_IP, to bind the master to a different IP address or hostname
# - SPARK_MASTER_PORT / SPARK_MASTER_WEBUI_PORT, to use non-default ports for the master
# - SPARK_MASTER_OPTS, to set config properties only for the master (e.g. "-Dx=y")
@@ -41,3 +41,10 @@
# - SPARK_HISTORY_OPTS, to set config properties only for the history server (e.g. "-Dx=y")
# - SPARK_DAEMON_JAVA_OPTS, to set config properties for all daemons (e.g. "-Dx=y")
# - SPARK_PUBLIC_DNS, to set the public dns name of the master or workers
+
+# Generic options for the daemons used in the standalone deploy mode
+# - SPARK_CONF_DIR Alternate conf dir. (Default: ${SPARK_HOME}/conf)
+# - SPARK_LOG_DIR Where log files are stored. (Default: ${SPARK_HOME}/logs)
+# - SPARK_PID_DIR Where the pid file is stored. (Default: /tmp)
+# - SPARK_IDENT_STRING A string representing this instance of spark. (Default: $USER)
+# - SPARK_NICENESS The scheduling priority for daemons. (Default: 0)
diff --git a/core/pom.xml b/core/pom.xml
index a5a178079bc57..1feb00b3a7fb8 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -21,7 +21,7 @@
org.apache.sparkspark-parent
- 1.2.0-SNAPSHOT
+ 1.3.0-SNAPSHOT../pom.xml
@@ -34,6 +34,34 @@
Spark Project Corehttp://spark.apache.org/
+
+ com.twitter
+ chill_${scala.binary.version}
+
+
+ org.ow2.asm
+ asm
+
+
+ org.ow2.asm
+ asm-commons
+
+
+
+
+ com.twitter
+ chill-java
+
+
+ org.ow2.asm
+ asm
+
+
+ org.ow2.asm
+ asm-commons
+
+
+ org.apache.hadoophadoop-client
@@ -44,6 +72,16 @@
+
+ org.apache.spark
+ spark-network-common_${scala.binary.version}
+ ${project.version}
+
+
+ org.apache.spark
+ spark-network-shuffle_${scala.binary.version}
+ ${project.version}
+ net.java.dev.jets3tjets3t
@@ -85,8 +123,6 @@
org.apache.commonscommons-math3
- 3.3
- testcom.google.code.findbugs
@@ -125,12 +161,8 @@
lz4
- com.twitter
- chill_${scala.binary.version}
-
-
- com.twitter
- chill-java
+ org.roaringbitmap
+ RoaringBitmapcommons-net
@@ -158,10 +190,6 @@
json4s-jackson_${scala.binary.version}3.2.10
-
- colt
- colt
- org.apache.mesosmesos
@@ -243,6 +271,11 @@
+
+ org.seleniumhq.selenium
+ selenium-java
+ test
+ org.scalatestscalatest_${scala.binary.version}
@@ -296,14 +329,16 @@
org.scalatestscalatest-maven-plugin
-
-
- ${basedir}/..
- 1
- ${spark.classpath}
-
-
+
+
+ test
+
+ test
+
+
+
+
org.apache.maven.plugins
@@ -411,4 +446,5 @@
+
diff --git a/core/src/main/java/org/apache/spark/JobExecutionStatus.java b/core/src/main/java/org/apache/spark/JobExecutionStatus.java
new file mode 100644
index 0000000000000..6e161313702bb
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/JobExecutionStatus.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+public enum JobExecutionStatus {
+ RUNNING,
+ SUCCEEDED,
+ FAILED,
+ UNKNOWN
+}
diff --git a/core/src/main/java/org/apache/spark/SparkJobInfo.java b/core/src/main/java/org/apache/spark/SparkJobInfo.java
new file mode 100644
index 0000000000000..4e3c983b1170a
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/SparkJobInfo.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+/**
+ * Exposes information about Spark Jobs.
+ *
+ * This interface is not designed to be implemented outside of Spark. We may add additional methods
+ * which may break binary compatibility with outside implementations.
+ */
+public interface SparkJobInfo {
+ int jobId();
+ int[] stageIds();
+ JobExecutionStatus status();
+}
diff --git a/core/src/main/java/org/apache/spark/SparkStageInfo.java b/core/src/main/java/org/apache/spark/SparkStageInfo.java
new file mode 100644
index 0000000000000..fd74321093658
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/SparkStageInfo.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+/**
+ * Exposes information about Spark Stages.
+ *
+ * This interface is not designed to be implemented outside of Spark. We may add additional methods
+ * which may break binary compatibility with outside implementations.
+ */
+public interface SparkStageInfo {
+ int stageId();
+ int currentAttemptId();
+ long submissionTime();
+ String name();
+ int numTasks();
+ int numActiveTasks();
+ int numCompletedTasks();
+ int numFailedTasks();
+}
diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java
index 4e6d708af0ea7..0d6973203eba1 100644
--- a/core/src/main/java/org/apache/spark/TaskContext.java
+++ b/core/src/main/java/org/apache/spark/TaskContext.java
@@ -18,252 +18,89 @@
package org.apache.spark;
import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
import scala.Function0;
import scala.Function1;
import scala.Unit;
-import scala.collection.JavaConversions;
import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.util.TaskCompletionListener;
-import org.apache.spark.util.TaskCompletionListenerException;
/**
-* :: DeveloperApi ::
-* Contextual information about a task which can be read or mutated during execution.
-*/
-@DeveloperApi
-public class TaskContext implements Serializable {
-
- private int stageId;
- private int partitionId;
- private long attemptId;
- private boolean runningLocally;
- private TaskMetrics taskMetrics;
-
- /**
- * :: DeveloperApi ::
- * Contextual information about a task which can be read or mutated during execution.
- *
- * @param stageId stage id
- * @param partitionId index of the partition
- * @param attemptId the number of attempts to execute this task
- * @param runningLocally whether the task is running locally in the driver JVM
- * @param taskMetrics performance metrics of the task
- */
- @DeveloperApi
- public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally,
- TaskMetrics taskMetrics) {
- this.attemptId = attemptId;
- this.partitionId = partitionId;
- this.runningLocally = runningLocally;
- this.stageId = stageId;
- this.taskMetrics = taskMetrics;
- }
-
- /**
- * :: DeveloperApi ::
- * Contextual information about a task which can be read or mutated during execution.
- *
- * @param stageId stage id
- * @param partitionId index of the partition
- * @param attemptId the number of attempts to execute this task
- * @param runningLocally whether the task is running locally in the driver JVM
- */
- @DeveloperApi
- public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally) {
- this.attemptId = attemptId;
- this.partitionId = partitionId;
- this.runningLocally = runningLocally;
- this.stageId = stageId;
- this.taskMetrics = TaskMetrics.empty();
- }
-
+ * Contextual information about a task which can be read or mutated during
+ * execution. To access the TaskContext for a running task use
+ * TaskContext.get().
+ */
+public abstract class TaskContext implements Serializable {
/**
- * :: DeveloperApi ::
- * Contextual information about a task which can be read or mutated during execution.
- *
- * @param stageId stage id
- * @param partitionId index of the partition
- * @param attemptId the number of attempts to execute this task
+ * Return the currently active TaskContext. This can be called inside of
+ * user functions to access contextual information about running tasks.
*/
- @DeveloperApi
- public TaskContext(int stageId, int partitionId, long attemptId) {
- this.attemptId = attemptId;
- this.partitionId = partitionId;
- this.runningLocally = false;
- this.stageId = stageId;
- this.taskMetrics = TaskMetrics.empty();
+ public static TaskContext get() {
+ return taskContext.get();
}
private static ThreadLocal taskContext =
new ThreadLocal();
- /**
- * :: Internal API ::
- * This is spark internal API, not intended to be called from user programs.
- */
- public static void setTaskContext(TaskContext tc) {
+ static void setTaskContext(TaskContext tc) {
taskContext.set(tc);
}
- public static TaskContext get() {
- return taskContext.get();
- }
-
- /** :: Internal API :: */
- public static void unset() {
+ static void unset() {
taskContext.remove();
}
- // List of callback functions to execute when the task completes.
- private transient List onCompleteCallbacks =
- new ArrayList();
-
- // Whether the corresponding task has been killed.
- private volatile boolean interrupted = false;
-
- // Whether the task has completed.
- private volatile boolean completed = false;
-
/**
- * Checks whether the task has completed.
+ * Whether the task has completed.
*/
- public boolean isCompleted() {
- return completed;
- }
+ public abstract boolean isCompleted();
/**
- * Checks whether the task has been killed.
+ * Whether the task has been killed.
*/
- public boolean isInterrupted() {
- return interrupted;
- }
+ public abstract boolean isInterrupted();
+
+ /** @deprecated: use isRunningLocally() */
+ @Deprecated
+ public abstract boolean runningLocally();
+
+ public abstract boolean isRunningLocally();
/**
* Add a (Java friendly) listener to be executed on task completion.
* This will be called in all situation - success, failure, or cancellation.
- *
* An example use is for HadoopRDD to register a callback to close the input stream.
*/
- public TaskContext addTaskCompletionListener(TaskCompletionListener listener) {
- onCompleteCallbacks.add(listener);
- return this;
- }
+ public abstract TaskContext addTaskCompletionListener(TaskCompletionListener listener);
/**
* Add a listener in the form of a Scala closure to be executed on task completion.
* This will be called in all situations - success, failure, or cancellation.
- *
* An example use is for HadoopRDD to register a callback to close the input stream.
*/
- public TaskContext addTaskCompletionListener(final Function1 f) {
- onCompleteCallbacks.add(new TaskCompletionListener() {
- @Override
- public void onTaskCompletion(TaskContext context) {
- f.apply(context);
- }
- });
- return this;
- }
+ public abstract TaskContext addTaskCompletionListener(final Function1 f);
/**
* Add a callback function to be executed on task completion. An example use
* is for HadoopRDD to register a callback to close the input stream.
* Will be called in any situation - success, failure, or cancellation.
*
- * Deprecated: use addTaskCompletionListener
- *
+ * @deprecated: use addTaskCompletionListener
+ *
* @param f Callback function.
*/
@Deprecated
- public void addOnCompleteCallback(final Function0 f) {
- onCompleteCallbacks.add(new TaskCompletionListener() {
- @Override
- public void onTaskCompletion(TaskContext context) {
- f.apply();
- }
- });
- }
-
- /**
- * ::Internal API::
- * Marks the task as completed and triggers the listeners.
- */
- public void markTaskCompleted() throws TaskCompletionListenerException {
- completed = true;
- List errorMsgs = new ArrayList(2);
- // Process complete callbacks in the reverse order of registration
- List revlist =
- new ArrayList(onCompleteCallbacks);
- Collections.reverse(revlist);
- for (TaskCompletionListener tcl: revlist) {
- try {
- tcl.onTaskCompletion(this);
- } catch (Throwable e) {
- errorMsgs.add(e.getMessage());
- }
- }
-
- if (!errorMsgs.isEmpty()) {
- throw new TaskCompletionListenerException(JavaConversions.asScalaBuffer(errorMsgs));
- }
- }
-
- /**
- * ::Internal API::
- * Marks the task for interruption, i.e. cancellation.
- */
- public void markInterrupted() {
- interrupted = true;
- }
-
- @Deprecated
- /** Deprecated: use getStageId() */
- public int stageId() {
- return stageId;
- }
-
- @Deprecated
- /** Deprecated: use getPartitionId() */
- public int partitionId() {
- return partitionId;
- }
-
- @Deprecated
- /** Deprecated: use getAttemptId() */
- public long attemptId() {
- return attemptId;
- }
-
- @Deprecated
- /** Deprecated: use isRunningLocally() */
- public boolean runningLocally() {
- return runningLocally;
- }
-
- public boolean isRunningLocally() {
- return runningLocally;
- }
+ public abstract void addOnCompleteCallback(final Function0 f);
- public int getStageId() {
- return stageId;
- }
+ public abstract int stageId();
- public int getPartitionId() {
- return partitionId;
- }
+ public abstract int partitionId();
- public long getAttemptId() {
- return attemptId;
- }
+ public abstract long attemptId();
- /** ::Internal API:: */
- public TaskMetrics taskMetrics() {
- return taskMetrics;
- }
+ /** ::DeveloperApi:: */
+ @DeveloperApi
+ public abstract TaskMetrics taskMetrics();
}
diff --git a/core/src/main/java/org/apache/spark/api/java/JavaFutureAction.java b/core/src/main/java/org/apache/spark/api/java/JavaFutureAction.java
new file mode 100644
index 0000000000000..0ad189633e427
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/JavaFutureAction.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java;
+
+
+import java.util.List;
+import java.util.concurrent.Future;
+
+public interface JavaFutureAction extends Future {
+
+ /**
+ * Returns the job IDs run by the underlying async operation.
+ *
+ * This returns the current snapshot of the job list. Certain operations may run multiple
+ * jobs, so multiple calls to this method may return different lists.
+ */
+ List jobIds();
+}
diff --git a/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java b/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java
index abd9bcc07ac61..99bf240a17225 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java
@@ -22,7 +22,8 @@
import scala.Tuple2;
/**
- * A function that returns key-value pairs (Tuple2), and can be used to construct PairRDDs.
+ * A function that returns key-value pairs (Tuple2<K, V>), and can be used to
+ * construct PairRDDs.
*/
public interface PairFunction extends Serializable {
public Tuple2 call(T t) throws Exception;
diff --git a/core/src/main/java/org/apache/spark/api/java/function/package.scala b/core/src/main/java/org/apache/spark/api/java/function/package.scala
index 7f91de653a64a..0f9bac7164162 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/package.scala
+++ b/core/src/main/java/org/apache/spark/api/java/function/package.scala
@@ -22,4 +22,4 @@ package org.apache.spark.api.java
* these interfaces to pass functions to various Java API methods for Spark. Please visit Spark's
* Java programming guide for more details.
*/
-package object function
\ No newline at end of file
+package object function
diff --git a/core/src/main/java/org/apache/spark/util/collection/Sorter.java b/core/src/main/java/org/apache/spark/util/collection/Sorter.java
deleted file mode 100644
index 64ad18c0e463a..0000000000000
--- a/core/src/main/java/org/apache/spark/util/collection/Sorter.java
+++ /dev/null
@@ -1,915 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection;
-
-import java.util.Comparator;
-
-/**
- * A port of the Android Timsort class, which utilizes a "stable, adaptive, iterative mergesort."
- * See the method comment on sort() for more details.
- *
- * This has been kept in Java with the original style in order to match very closely with the
- * Anroid source code, and thus be easy to verify correctness.
- *
- * The purpose of the port is to generalize the interface to the sort to accept input data formats
- * besides simple arrays where every element is sorted individually. For instance, the AppendOnlyMap
- * uses this to sort an Array with alternating elements of the form [key, value, key, value].
- * This generalization comes with minimal overhead -- see SortDataFormat for more information.
- */
-class Sorter {
-
- /**
- * This is the minimum sized sequence that will be merged. Shorter
- * sequences will be lengthened by calling binarySort. If the entire
- * array is less than this length, no merges will be performed.
- *
- * This constant should be a power of two. It was 64 in Tim Peter's C
- * implementation, but 32 was empirically determined to work better in
- * this implementation. In the unlikely event that you set this constant
- * to be a number that's not a power of two, you'll need to change the
- * minRunLength computation.
- *
- * If you decrease this constant, you must change the stackLen
- * computation in the TimSort constructor, or you risk an
- * ArrayOutOfBounds exception. See listsort.txt for a discussion
- * of the minimum stack length required as a function of the length
- * of the array being sorted and the minimum merge sequence length.
- */
- private static final int MIN_MERGE = 32;
-
- private final SortDataFormat s;
-
- public Sorter(SortDataFormat sortDataFormat) {
- this.s = sortDataFormat;
- }
-
- /**
- * A stable, adaptive, iterative mergesort that requires far fewer than
- * n lg(n) comparisons when running on partially sorted arrays, while
- * offering performance comparable to a traditional mergesort when run
- * on random arrays. Like all proper mergesorts, this sort is stable and
- * runs O(n log n) time (worst case). In the worst case, this sort requires
- * temporary storage space for n/2 object references; in the best case,
- * it requires only a small constant amount of space.
- *
- * This implementation was adapted from Tim Peters's list sort for
- * Python, which is described in detail here:
- *
- * http://svn.python.org/projects/python/trunk/Objects/listsort.txt
- *
- * Tim's C code may be found here:
- *
- * http://svn.python.org/projects/python/trunk/Objects/listobject.c
- *
- * The underlying techniques are described in this paper (and may have
- * even earlier origins):
- *
- * "Optimistic Sorting and Information Theoretic Complexity"
- * Peter McIlroy
- * SODA (Fourth Annual ACM-SIAM Symposium on Discrete Algorithms),
- * pp 467-474, Austin, Texas, 25-27 January 1993.
- *
- * While the API to this class consists solely of static methods, it is
- * (privately) instantiable; a TimSort instance holds the state of an ongoing
- * sort, assuming the input array is large enough to warrant the full-blown
- * TimSort. Small arrays are sorted in place, using a binary insertion sort.
- *
- * @author Josh Bloch
- */
- void sort(Buffer a, int lo, int hi, Comparator super K> c) {
- assert c != null;
-
- int nRemaining = hi - lo;
- if (nRemaining < 2)
- return; // Arrays of size 0 and 1 are always sorted
-
- // If array is small, do a "mini-TimSort" with no merges
- if (nRemaining < MIN_MERGE) {
- int initRunLen = countRunAndMakeAscending(a, lo, hi, c);
- binarySort(a, lo, hi, lo + initRunLen, c);
- return;
- }
-
- /**
- * March over the array once, left to right, finding natural runs,
- * extending short natural runs to minRun elements, and merging runs
- * to maintain stack invariant.
- */
- SortState sortState = new SortState(a, c, hi - lo);
- int minRun = minRunLength(nRemaining);
- do {
- // Identify next run
- int runLen = countRunAndMakeAscending(a, lo, hi, c);
-
- // If run is short, extend to min(minRun, nRemaining)
- if (runLen < minRun) {
- int force = nRemaining <= minRun ? nRemaining : minRun;
- binarySort(a, lo, lo + force, lo + runLen, c);
- runLen = force;
- }
-
- // Push run onto pending-run stack, and maybe merge
- sortState.pushRun(lo, runLen);
- sortState.mergeCollapse();
-
- // Advance to find next run
- lo += runLen;
- nRemaining -= runLen;
- } while (nRemaining != 0);
-
- // Merge all remaining runs to complete sort
- assert lo == hi;
- sortState.mergeForceCollapse();
- assert sortState.stackSize == 1;
- }
-
- /**
- * Sorts the specified portion of the specified array using a binary
- * insertion sort. This is the best method for sorting small numbers
- * of elements. It requires O(n log n) compares, but O(n^2) data
- * movement (worst case).
- *
- * If the initial part of the specified range is already sorted,
- * this method can take advantage of it: the method assumes that the
- * elements from index {@code lo}, inclusive, to {@code start},
- * exclusive are already sorted.
- *
- * @param a the array in which a range is to be sorted
- * @param lo the index of the first element in the range to be sorted
- * @param hi the index after the last element in the range to be sorted
- * @param start the index of the first element in the range that is
- * not already known to be sorted ({@code lo <= start <= hi})
- * @param c comparator to used for the sort
- */
- @SuppressWarnings("fallthrough")
- private void binarySort(Buffer a, int lo, int hi, int start, Comparator super K> c) {
- assert lo <= start && start <= hi;
- if (start == lo)
- start++;
-
- Buffer pivotStore = s.allocate(1);
- for ( ; start < hi; start++) {
- s.copyElement(a, start, pivotStore, 0);
- K pivot = s.getKey(pivotStore, 0);
-
- // Set left (and right) to the index where a[start] (pivot) belongs
- int left = lo;
- int right = start;
- assert left <= right;
- /*
- * Invariants:
- * pivot >= all in [lo, left).
- * pivot < all in [right, start).
- */
- while (left < right) {
- int mid = (left + right) >>> 1;
- if (c.compare(pivot, s.getKey(a, mid)) < 0)
- right = mid;
- else
- left = mid + 1;
- }
- assert left == right;
-
- /*
- * The invariants still hold: pivot >= all in [lo, left) and
- * pivot < all in [left, start), so pivot belongs at left. Note
- * that if there are elements equal to pivot, left points to the
- * first slot after them -- that's why this sort is stable.
- * Slide elements over to make room for pivot.
- */
- int n = start - left; // The number of elements to move
- // Switch is just an optimization for arraycopy in default case
- switch (n) {
- case 2: s.copyElement(a, left + 1, a, left + 2);
- case 1: s.copyElement(a, left, a, left + 1);
- break;
- default: s.copyRange(a, left, a, left + 1, n);
- }
- s.copyElement(pivotStore, 0, a, left);
- }
- }
-
- /**
- * Returns the length of the run beginning at the specified position in
- * the specified array and reverses the run if it is descending (ensuring
- * that the run will always be ascending when the method returns).
- *
- * A run is the longest ascending sequence with:
- *
- * a[lo] <= a[lo + 1] <= a[lo + 2] <= ...
- *
- * or the longest descending sequence with:
- *
- * a[lo] > a[lo + 1] > a[lo + 2] > ...
- *
- * For its intended use in a stable mergesort, the strictness of the
- * definition of "descending" is needed so that the call can safely
- * reverse a descending sequence without violating stability.
- *
- * @param a the array in which a run is to be counted and possibly reversed
- * @param lo index of the first element in the run
- * @param hi index after the last element that may be contained in the run.
- It is required that {@code lo < hi}.
- * @param c the comparator to used for the sort
- * @return the length of the run beginning at the specified position in
- * the specified array
- */
- private int countRunAndMakeAscending(Buffer a, int lo, int hi, Comparator super K> c) {
- assert lo < hi;
- int runHi = lo + 1;
- if (runHi == hi)
- return 1;
-
- // Find end of run, and reverse range if descending
- if (c.compare(s.getKey(a, runHi++), s.getKey(a, lo)) < 0) { // Descending
- while (runHi < hi && c.compare(s.getKey(a, runHi), s.getKey(a, runHi - 1)) < 0)
- runHi++;
- reverseRange(a, lo, runHi);
- } else { // Ascending
- while (runHi < hi && c.compare(s.getKey(a, runHi), s.getKey(a, runHi - 1)) >= 0)
- runHi++;
- }
-
- return runHi - lo;
- }
-
- /**
- * Reverse the specified range of the specified array.
- *
- * @param a the array in which a range is to be reversed
- * @param lo the index of the first element in the range to be reversed
- * @param hi the index after the last element in the range to be reversed
- */
- private void reverseRange(Buffer a, int lo, int hi) {
- hi--;
- while (lo < hi) {
- s.swap(a, lo, hi);
- lo++;
- hi--;
- }
- }
-
- /**
- * Returns the minimum acceptable run length for an array of the specified
- * length. Natural runs shorter than this will be extended with
- * {@link #binarySort}.
- *
- * Roughly speaking, the computation is:
- *
- * If n < MIN_MERGE, return n (it's too small to bother with fancy stuff).
- * Else if n is an exact power of 2, return MIN_MERGE/2.
- * Else return an int k, MIN_MERGE/2 <= k <= MIN_MERGE, such that n/k
- * is close to, but strictly less than, an exact power of 2.
- *
- * For the rationale, see listsort.txt.
- *
- * @param n the length of the array to be sorted
- * @return the length of the minimum run to be merged
- */
- private int minRunLength(int n) {
- assert n >= 0;
- int r = 0; // Becomes 1 if any 1 bits are shifted off
- while (n >= MIN_MERGE) {
- r |= (n & 1);
- n >>= 1;
- }
- return n + r;
- }
-
- private class SortState {
-
- /**
- * The Buffer being sorted.
- */
- private final Buffer a;
-
- /**
- * Length of the sort Buffer.
- */
- private final int aLength;
-
- /**
- * The comparator for this sort.
- */
- private final Comparator super K> c;
-
- /**
- * When we get into galloping mode, we stay there until both runs win less
- * often than MIN_GALLOP consecutive times.
- */
- private static final int MIN_GALLOP = 7;
-
- /**
- * This controls when we get *into* galloping mode. It is initialized
- * to MIN_GALLOP. The mergeLo and mergeHi methods nudge it higher for
- * random data, and lower for highly structured data.
- */
- private int minGallop = MIN_GALLOP;
-
- /**
- * Maximum initial size of tmp array, which is used for merging. The array
- * can grow to accommodate demand.
- *
- * Unlike Tim's original C version, we do not allocate this much storage
- * when sorting smaller arrays. This change was required for performance.
- */
- private static final int INITIAL_TMP_STORAGE_LENGTH = 256;
-
- /**
- * Temp storage for merges.
- */
- private Buffer tmp; // Actual runtime type will be Object[], regardless of T
-
- /**
- * Length of the temp storage.
- */
- private int tmpLength = 0;
-
- /**
- * A stack of pending runs yet to be merged. Run i starts at
- * address base[i] and extends for len[i] elements. It's always
- * true (so long as the indices are in bounds) that:
- *
- * runBase[i] + runLen[i] == runBase[i + 1]
- *
- * so we could cut the storage for this, but it's a minor amount,
- * and keeping all the info explicit simplifies the code.
- */
- private int stackSize = 0; // Number of pending runs on stack
- private final int[] runBase;
- private final int[] runLen;
-
- /**
- * Creates a TimSort instance to maintain the state of an ongoing sort.
- *
- * @param a the array to be sorted
- * @param c the comparator to determine the order of the sort
- */
- private SortState(Buffer a, Comparator super K> c, int len) {
- this.aLength = len;
- this.a = a;
- this.c = c;
-
- // Allocate temp storage (which may be increased later if necessary)
- tmpLength = len < 2 * INITIAL_TMP_STORAGE_LENGTH ? len >>> 1 : INITIAL_TMP_STORAGE_LENGTH;
- tmp = s.allocate(tmpLength);
-
- /*
- * Allocate runs-to-be-merged stack (which cannot be expanded). The
- * stack length requirements are described in listsort.txt. The C
- * version always uses the same stack length (85), but this was
- * measured to be too expensive when sorting "mid-sized" arrays (e.g.,
- * 100 elements) in Java. Therefore, we use smaller (but sufficiently
- * large) stack lengths for smaller arrays. The "magic numbers" in the
- * computation below must be changed if MIN_MERGE is decreased. See
- * the MIN_MERGE declaration above for more information.
- */
- int stackLen = (len < 120 ? 5 :
- len < 1542 ? 10 :
- len < 119151 ? 19 : 40);
- runBase = new int[stackLen];
- runLen = new int[stackLen];
- }
-
- /**
- * Pushes the specified run onto the pending-run stack.
- *
- * @param runBase index of the first element in the run
- * @param runLen the number of elements in the run
- */
- private void pushRun(int runBase, int runLen) {
- this.runBase[stackSize] = runBase;
- this.runLen[stackSize] = runLen;
- stackSize++;
- }
-
- /**
- * Examines the stack of runs waiting to be merged and merges adjacent runs
- * until the stack invariants are reestablished:
- *
- * 1. runLen[i - 3] > runLen[i - 2] + runLen[i - 1]
- * 2. runLen[i - 2] > runLen[i - 1]
- *
- * This method is called each time a new run is pushed onto the stack,
- * so the invariants are guaranteed to hold for i < stackSize upon
- * entry to the method.
- */
- private void mergeCollapse() {
- while (stackSize > 1) {
- int n = stackSize - 2;
- if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) {
- if (runLen[n - 1] < runLen[n + 1])
- n--;
- mergeAt(n);
- } else if (runLen[n] <= runLen[n + 1]) {
- mergeAt(n);
- } else {
- break; // Invariant is established
- }
- }
- }
-
- /**
- * Merges all runs on the stack until only one remains. This method is
- * called once, to complete the sort.
- */
- private void mergeForceCollapse() {
- while (stackSize > 1) {
- int n = stackSize - 2;
- if (n > 0 && runLen[n - 1] < runLen[n + 1])
- n--;
- mergeAt(n);
- }
- }
-
- /**
- * Merges the two runs at stack indices i and i+1. Run i must be
- * the penultimate or antepenultimate run on the stack. In other words,
- * i must be equal to stackSize-2 or stackSize-3.
- *
- * @param i stack index of the first of the two runs to merge
- */
- private void mergeAt(int i) {
- assert stackSize >= 2;
- assert i >= 0;
- assert i == stackSize - 2 || i == stackSize - 3;
-
- int base1 = runBase[i];
- int len1 = runLen[i];
- int base2 = runBase[i + 1];
- int len2 = runLen[i + 1];
- assert len1 > 0 && len2 > 0;
- assert base1 + len1 == base2;
-
- /*
- * Record the length of the combined runs; if i is the 3rd-last
- * run now, also slide over the last run (which isn't involved
- * in this merge). The current run (i+1) goes away in any case.
- */
- runLen[i] = len1 + len2;
- if (i == stackSize - 3) {
- runBase[i + 1] = runBase[i + 2];
- runLen[i + 1] = runLen[i + 2];
- }
- stackSize--;
-
- /*
- * Find where the first element of run2 goes in run1. Prior elements
- * in run1 can be ignored (because they're already in place).
- */
- int k = gallopRight(s.getKey(a, base2), a, base1, len1, 0, c);
- assert k >= 0;
- base1 += k;
- len1 -= k;
- if (len1 == 0)
- return;
-
- /*
- * Find where the last element of run1 goes in run2. Subsequent elements
- * in run2 can be ignored (because they're already in place).
- */
- len2 = gallopLeft(s.getKey(a, base1 + len1 - 1), a, base2, len2, len2 - 1, c);
- assert len2 >= 0;
- if (len2 == 0)
- return;
-
- // Merge remaining runs, using tmp array with min(len1, len2) elements
- if (len1 <= len2)
- mergeLo(base1, len1, base2, len2);
- else
- mergeHi(base1, len1, base2, len2);
- }
-
- /**
- * Locates the position at which to insert the specified key into the
- * specified sorted range; if the range contains an element equal to key,
- * returns the index of the leftmost equal element.
- *
- * @param key the key whose insertion point to search for
- * @param a the array in which to search
- * @param base the index of the first element in the range
- * @param len the length of the range; must be > 0
- * @param hint the index at which to begin the search, 0 <= hint < n.
- * The closer hint is to the result, the faster this method will run.
- * @param c the comparator used to order the range, and to search
- * @return the int k, 0 <= k <= n such that a[b + k - 1] < key <= a[b + k],
- * pretending that a[b - 1] is minus infinity and a[b + n] is infinity.
- * In other words, key belongs at index b + k; or in other words,
- * the first k elements of a should precede key, and the last n - k
- * should follow it.
- */
- private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator super K> c) {
- assert len > 0 && hint >= 0 && hint < len;
- int lastOfs = 0;
- int ofs = 1;
- if (c.compare(key, s.getKey(a, base + hint)) > 0) {
- // Gallop right until a[base+hint+lastOfs] < key <= a[base+hint+ofs]
- int maxOfs = len - hint;
- while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs)) > 0) {
- lastOfs = ofs;
- ofs = (ofs << 1) + 1;
- if (ofs <= 0) // int overflow
- ofs = maxOfs;
- }
- if (ofs > maxOfs)
- ofs = maxOfs;
-
- // Make offsets relative to base
- lastOfs += hint;
- ofs += hint;
- } else { // key <= a[base + hint]
- // Gallop left until a[base+hint-ofs] < key <= a[base+hint-lastOfs]
- final int maxOfs = hint + 1;
- while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs)) <= 0) {
- lastOfs = ofs;
- ofs = (ofs << 1) + 1;
- if (ofs <= 0) // int overflow
- ofs = maxOfs;
- }
- if (ofs > maxOfs)
- ofs = maxOfs;
-
- // Make offsets relative to base
- int tmp = lastOfs;
- lastOfs = hint - ofs;
- ofs = hint - tmp;
- }
- assert -1 <= lastOfs && lastOfs < ofs && ofs <= len;
-
- /*
- * Now a[base+lastOfs] < key <= a[base+ofs], so key belongs somewhere
- * to the right of lastOfs but no farther right than ofs. Do a binary
- * search, with invariant a[base + lastOfs - 1] < key <= a[base + ofs].
- */
- lastOfs++;
- while (lastOfs < ofs) {
- int m = lastOfs + ((ofs - lastOfs) >>> 1);
-
- if (c.compare(key, s.getKey(a, base + m)) > 0)
- lastOfs = m + 1; // a[base + m] < key
- else
- ofs = m; // key <= a[base + m]
- }
- assert lastOfs == ofs; // so a[base + ofs - 1] < key <= a[base + ofs]
- return ofs;
- }
-
- /**
- * Like gallopLeft, except that if the range contains an element equal to
- * key, gallopRight returns the index after the rightmost equal element.
- *
- * @param key the key whose insertion point to search for
- * @param a the array in which to search
- * @param base the index of the first element in the range
- * @param len the length of the range; must be > 0
- * @param hint the index at which to begin the search, 0 <= hint < n.
- * The closer hint is to the result, the faster this method will run.
- * @param c the comparator used to order the range, and to search
- * @return the int k, 0 <= k <= n such that a[b + k - 1] <= key < a[b + k]
- */
- private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator super K> c) {
- assert len > 0 && hint >= 0 && hint < len;
-
- int ofs = 1;
- int lastOfs = 0;
- if (c.compare(key, s.getKey(a, base + hint)) < 0) {
- // Gallop left until a[b+hint - ofs] <= key < a[b+hint - lastOfs]
- int maxOfs = hint + 1;
- while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs)) < 0) {
- lastOfs = ofs;
- ofs = (ofs << 1) + 1;
- if (ofs <= 0) // int overflow
- ofs = maxOfs;
- }
- if (ofs > maxOfs)
- ofs = maxOfs;
-
- // Make offsets relative to b
- int tmp = lastOfs;
- lastOfs = hint - ofs;
- ofs = hint - tmp;
- } else { // a[b + hint] <= key
- // Gallop right until a[b+hint + lastOfs] <= key < a[b+hint + ofs]
- int maxOfs = len - hint;
- while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs)) >= 0) {
- lastOfs = ofs;
- ofs = (ofs << 1) + 1;
- if (ofs <= 0) // int overflow
- ofs = maxOfs;
- }
- if (ofs > maxOfs)
- ofs = maxOfs;
-
- // Make offsets relative to b
- lastOfs += hint;
- ofs += hint;
- }
- assert -1 <= lastOfs && lastOfs < ofs && ofs <= len;
-
- /*
- * Now a[b + lastOfs] <= key < a[b + ofs], so key belongs somewhere to
- * the right of lastOfs but no farther right than ofs. Do a binary
- * search, with invariant a[b + lastOfs - 1] <= key < a[b + ofs].
- */
- lastOfs++;
- while (lastOfs < ofs) {
- int m = lastOfs + ((ofs - lastOfs) >>> 1);
-
- if (c.compare(key, s.getKey(a, base + m)) < 0)
- ofs = m; // key < a[b + m]
- else
- lastOfs = m + 1; // a[b + m] <= key
- }
- assert lastOfs == ofs; // so a[b + ofs - 1] <= key < a[b + ofs]
- return ofs;
- }
-
- /**
- * Merges two adjacent runs in place, in a stable fashion. The first
- * element of the first run must be greater than the first element of the
- * second run (a[base1] > a[base2]), and the last element of the first run
- * (a[base1 + len1-1]) must be greater than all elements of the second run.
- *
- * For performance, this method should be called only when len1 <= len2;
- * its twin, mergeHi should be called if len1 >= len2. (Either method
- * may be called if len1 == len2.)
- *
- * @param base1 index of first element in first run to be merged
- * @param len1 length of first run to be merged (must be > 0)
- * @param base2 index of first element in second run to be merged
- * (must be aBase + aLen)
- * @param len2 length of second run to be merged (must be > 0)
- */
- private void mergeLo(int base1, int len1, int base2, int len2) {
- assert len1 > 0 && len2 > 0 && base1 + len1 == base2;
-
- // Copy first run into temp array
- Buffer a = this.a; // For performance
- Buffer tmp = ensureCapacity(len1);
- s.copyRange(a, base1, tmp, 0, len1);
-
- int cursor1 = 0; // Indexes into tmp array
- int cursor2 = base2; // Indexes int a
- int dest = base1; // Indexes int a
-
- // Move first element of second run and deal with degenerate cases
- s.copyElement(a, cursor2++, a, dest++);
- if (--len2 == 0) {
- s.copyRange(tmp, cursor1, a, dest, len1);
- return;
- }
- if (len1 == 1) {
- s.copyRange(a, cursor2, a, dest, len2);
- s.copyElement(tmp, cursor1, a, dest + len2); // Last elt of run 1 to end of merge
- return;
- }
-
- Comparator super K> c = this.c; // Use local variable for performance
- int minGallop = this.minGallop; // " " " " "
- outer:
- while (true) {
- int count1 = 0; // Number of times in a row that first run won
- int count2 = 0; // Number of times in a row that second run won
-
- /*
- * Do the straightforward thing until (if ever) one run starts
- * winning consistently.
- */
- do {
- assert len1 > 1 && len2 > 0;
- if (c.compare(s.getKey(a, cursor2), s.getKey(tmp, cursor1)) < 0) {
- s.copyElement(a, cursor2++, a, dest++);
- count2++;
- count1 = 0;
- if (--len2 == 0)
- break outer;
- } else {
- s.copyElement(tmp, cursor1++, a, dest++);
- count1++;
- count2 = 0;
- if (--len1 == 1)
- break outer;
- }
- } while ((count1 | count2) < minGallop);
-
- /*
- * One run is winning so consistently that galloping may be a
- * huge win. So try that, and continue galloping until (if ever)
- * neither run appears to be winning consistently anymore.
- */
- do {
- assert len1 > 1 && len2 > 0;
- count1 = gallopRight(s.getKey(a, cursor2), tmp, cursor1, len1, 0, c);
- if (count1 != 0) {
- s.copyRange(tmp, cursor1, a, dest, count1);
- dest += count1;
- cursor1 += count1;
- len1 -= count1;
- if (len1 <= 1) // len1 == 1 || len1 == 0
- break outer;
- }
- s.copyElement(a, cursor2++, a, dest++);
- if (--len2 == 0)
- break outer;
-
- count2 = gallopLeft(s.getKey(tmp, cursor1), a, cursor2, len2, 0, c);
- if (count2 != 0) {
- s.copyRange(a, cursor2, a, dest, count2);
- dest += count2;
- cursor2 += count2;
- len2 -= count2;
- if (len2 == 0)
- break outer;
- }
- s.copyElement(tmp, cursor1++, a, dest++);
- if (--len1 == 1)
- break outer;
- minGallop--;
- } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP);
- if (minGallop < 0)
- minGallop = 0;
- minGallop += 2; // Penalize for leaving gallop mode
- } // End of "outer" loop
- this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field
-
- if (len1 == 1) {
- assert len2 > 0;
- s.copyRange(a, cursor2, a, dest, len2);
- s.copyElement(tmp, cursor1, a, dest + len2); // Last elt of run 1 to end of merge
- } else if (len1 == 0) {
- throw new IllegalArgumentException(
- "Comparison method violates its general contract!");
- } else {
- assert len2 == 0;
- assert len1 > 1;
- s.copyRange(tmp, cursor1, a, dest, len1);
- }
- }
-
- /**
- * Like mergeLo, except that this method should be called only if
- * len1 >= len2; mergeLo should be called if len1 <= len2. (Either method
- * may be called if len1 == len2.)
- *
- * @param base1 index of first element in first run to be merged
- * @param len1 length of first run to be merged (must be > 0)
- * @param base2 index of first element in second run to be merged
- * (must be aBase + aLen)
- * @param len2 length of second run to be merged (must be > 0)
- */
- private void mergeHi(int base1, int len1, int base2, int len2) {
- assert len1 > 0 && len2 > 0 && base1 + len1 == base2;
-
- // Copy second run into temp array
- Buffer a = this.a; // For performance
- Buffer tmp = ensureCapacity(len2);
- s.copyRange(a, base2, tmp, 0, len2);
-
- int cursor1 = base1 + len1 - 1; // Indexes into a
- int cursor2 = len2 - 1; // Indexes into tmp array
- int dest = base2 + len2 - 1; // Indexes into a
-
- // Move last element of first run and deal with degenerate cases
- s.copyElement(a, cursor1--, a, dest--);
- if (--len1 == 0) {
- s.copyRange(tmp, 0, a, dest - (len2 - 1), len2);
- return;
- }
- if (len2 == 1) {
- dest -= len1;
- cursor1 -= len1;
- s.copyRange(a, cursor1 + 1, a, dest + 1, len1);
- s.copyElement(tmp, cursor2, a, dest);
- return;
- }
-
- Comparator super K> c = this.c; // Use local variable for performance
- int minGallop = this.minGallop; // " " " " "
- outer:
- while (true) {
- int count1 = 0; // Number of times in a row that first run won
- int count2 = 0; // Number of times in a row that second run won
-
- /*
- * Do the straightforward thing until (if ever) one run
- * appears to win consistently.
- */
- do {
- assert len1 > 0 && len2 > 1;
- if (c.compare(s.getKey(tmp, cursor2), s.getKey(a, cursor1)) < 0) {
- s.copyElement(a, cursor1--, a, dest--);
- count1++;
- count2 = 0;
- if (--len1 == 0)
- break outer;
- } else {
- s.copyElement(tmp, cursor2--, a, dest--);
- count2++;
- count1 = 0;
- if (--len2 == 1)
- break outer;
- }
- } while ((count1 | count2) < minGallop);
-
- /*
- * One run is winning so consistently that galloping may be a
- * huge win. So try that, and continue galloping until (if ever)
- * neither run appears to be winning consistently anymore.
- */
- do {
- assert len1 > 0 && len2 > 1;
- count1 = len1 - gallopRight(s.getKey(tmp, cursor2), a, base1, len1, len1 - 1, c);
- if (count1 != 0) {
- dest -= count1;
- cursor1 -= count1;
- len1 -= count1;
- s.copyRange(a, cursor1 + 1, a, dest + 1, count1);
- if (len1 == 0)
- break outer;
- }
- s.copyElement(tmp, cursor2--, a, dest--);
- if (--len2 == 1)
- break outer;
-
- count2 = len2 - gallopLeft(s.getKey(a, cursor1), tmp, 0, len2, len2 - 1, c);
- if (count2 != 0) {
- dest -= count2;
- cursor2 -= count2;
- len2 -= count2;
- s.copyRange(tmp, cursor2 + 1, a, dest + 1, count2);
- if (len2 <= 1) // len2 == 1 || len2 == 0
- break outer;
- }
- s.copyElement(a, cursor1--, a, dest--);
- if (--len1 == 0)
- break outer;
- minGallop--;
- } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP);
- if (minGallop < 0)
- minGallop = 0;
- minGallop += 2; // Penalize for leaving gallop mode
- } // End of "outer" loop
- this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field
-
- if (len2 == 1) {
- assert len1 > 0;
- dest -= len1;
- cursor1 -= len1;
- s.copyRange(a, cursor1 + 1, a, dest + 1, len1);
- s.copyElement(tmp, cursor2, a, dest); // Move first elt of run2 to front of merge
- } else if (len2 == 0) {
- throw new IllegalArgumentException(
- "Comparison method violates its general contract!");
- } else {
- assert len1 == 0;
- assert len2 > 0;
- s.copyRange(tmp, 0, a, dest - (len2 - 1), len2);
- }
- }
-
- /**
- * Ensures that the external array tmp has at least the specified
- * number of elements, increasing its size if necessary. The size
- * increases exponentially to ensure amortized linear time complexity.
- *
- * @param minCapacity the minimum required capacity of the tmp array
- * @return tmp, whether or not it grew
- */
- private Buffer ensureCapacity(int minCapacity) {
- if (tmpLength < minCapacity) {
- // Compute smallest power of 2 > minCapacity
- int newSize = minCapacity;
- newSize |= newSize >> 1;
- newSize |= newSize >> 2;
- newSize |= newSize >> 4;
- newSize |= newSize >> 8;
- newSize |= newSize >> 16;
- newSize++;
-
- if (newSize < 0) // Not bloody likely!
- newSize = minCapacity;
- else
- newSize = Math.min(newSize, aLength >>> 1);
-
- tmp = s.allocate(newSize);
- tmpLength = newSize;
- }
- return tmp;
- }
- }
-}
diff --git a/core/src/main/java/org/apache/spark/util/collection/TimSort.java b/core/src/main/java/org/apache/spark/util/collection/TimSort.java
new file mode 100644
index 0000000000000..409e1a41c5d49
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/TimSort.java
@@ -0,0 +1,940 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection;
+
+import java.util.Comparator;
+
+/**
+ * A port of the Android TimSort class, which utilizes a "stable, adaptive, iterative mergesort."
+ * See the method comment on sort() for more details.
+ *
+ * This has been kept in Java with the original style in order to match very closely with the
+ * Android source code, and thus be easy to verify correctness. The class is package private. We put
+ * a simple Scala wrapper {@link org.apache.spark.util.collection.Sorter}, which is available to
+ * package org.apache.spark.
+ *
+ * The purpose of the port is to generalize the interface to the sort to accept input data formats
+ * besides simple arrays where every element is sorted individually. For instance, the AppendOnlyMap
+ * uses this to sort an Array with alternating elements of the form [key, value, key, value].
+ * This generalization comes with minimal overhead -- see SortDataFormat for more information.
+ *
+ * We allow key reuse to prevent creating many key objects -- see SortDataFormat.
+ *
+ * @see org.apache.spark.util.collection.SortDataFormat
+ * @see org.apache.spark.util.collection.Sorter
+ */
+class TimSort {
+
+ /**
+ * This is the minimum sized sequence that will be merged. Shorter
+ * sequences will be lengthened by calling binarySort. If the entire
+ * array is less than this length, no merges will be performed.
+ *
+ * This constant should be a power of two. It was 64 in Tim Peter's C
+ * implementation, but 32 was empirically determined to work better in
+ * this implementation. In the unlikely event that you set this constant
+ * to be a number that's not a power of two, you'll need to change the
+ * minRunLength computation.
+ *
+ * If you decrease this constant, you must change the stackLen
+ * computation in the TimSort constructor, or you risk an
+ * ArrayOutOfBounds exception. See listsort.txt for a discussion
+ * of the minimum stack length required as a function of the length
+ * of the array being sorted and the minimum merge sequence length.
+ */
+ private static final int MIN_MERGE = 32;
+
+ private final SortDataFormat s;
+
+ public TimSort(SortDataFormat sortDataFormat) {
+ this.s = sortDataFormat;
+ }
+
+ /**
+ * A stable, adaptive, iterative mergesort that requires far fewer than
+ * n lg(n) comparisons when running on partially sorted arrays, while
+ * offering performance comparable to a traditional mergesort when run
+ * on random arrays. Like all proper mergesorts, this sort is stable and
+ * runs O(n log n) time (worst case). In the worst case, this sort requires
+ * temporary storage space for n/2 object references; in the best case,
+ * it requires only a small constant amount of space.
+ *
+ * This implementation was adapted from Tim Peters's list sort for
+ * Python, which is described in detail here:
+ *
+ * http://svn.python.org/projects/python/trunk/Objects/listsort.txt
+ *
+ * Tim's C code may be found here:
+ *
+ * http://svn.python.org/projects/python/trunk/Objects/listobject.c
+ *
+ * The underlying techniques are described in this paper (and may have
+ * even earlier origins):
+ *
+ * "Optimistic Sorting and Information Theoretic Complexity"
+ * Peter McIlroy
+ * SODA (Fourth Annual ACM-SIAM Symposium on Discrete Algorithms),
+ * pp 467-474, Austin, Texas, 25-27 January 1993.
+ *
+ * While the API to this class consists solely of static methods, it is
+ * (privately) instantiable; a TimSort instance holds the state of an ongoing
+ * sort, assuming the input array is large enough to warrant the full-blown
+ * TimSort. Small arrays are sorted in place, using a binary insertion sort.
+ *
+ * @author Josh Bloch
+ */
+ public void sort(Buffer a, int lo, int hi, Comparator super K> c) {
+ assert c != null;
+
+ int nRemaining = hi - lo;
+ if (nRemaining < 2)
+ return; // Arrays of size 0 and 1 are always sorted
+
+ // If array is small, do a "mini-TimSort" with no merges
+ if (nRemaining < MIN_MERGE) {
+ int initRunLen = countRunAndMakeAscending(a, lo, hi, c);
+ binarySort(a, lo, hi, lo + initRunLen, c);
+ return;
+ }
+
+ /**
+ * March over the array once, left to right, finding natural runs,
+ * extending short natural runs to minRun elements, and merging runs
+ * to maintain stack invariant.
+ */
+ SortState sortState = new SortState(a, c, hi - lo);
+ int minRun = minRunLength(nRemaining);
+ do {
+ // Identify next run
+ int runLen = countRunAndMakeAscending(a, lo, hi, c);
+
+ // If run is short, extend to min(minRun, nRemaining)
+ if (runLen < minRun) {
+ int force = nRemaining <= minRun ? nRemaining : minRun;
+ binarySort(a, lo, lo + force, lo + runLen, c);
+ runLen = force;
+ }
+
+ // Push run onto pending-run stack, and maybe merge
+ sortState.pushRun(lo, runLen);
+ sortState.mergeCollapse();
+
+ // Advance to find next run
+ lo += runLen;
+ nRemaining -= runLen;
+ } while (nRemaining != 0);
+
+ // Merge all remaining runs to complete sort
+ assert lo == hi;
+ sortState.mergeForceCollapse();
+ assert sortState.stackSize == 1;
+ }
+
+ /**
+ * Sorts the specified portion of the specified array using a binary
+ * insertion sort. This is the best method for sorting small numbers
+ * of elements. It requires O(n log n) compares, but O(n^2) data
+ * movement (worst case).
+ *
+ * If the initial part of the specified range is already sorted,
+ * this method can take advantage of it: the method assumes that the
+ * elements from index {@code lo}, inclusive, to {@code start},
+ * exclusive are already sorted.
+ *
+ * @param a the array in which a range is to be sorted
+ * @param lo the index of the first element in the range to be sorted
+ * @param hi the index after the last element in the range to be sorted
+ * @param start the index of the first element in the range that is
+ * not already known to be sorted ({@code lo <= start <= hi})
+ * @param c comparator to used for the sort
+ */
+ @SuppressWarnings("fallthrough")
+ private void binarySort(Buffer a, int lo, int hi, int start, Comparator super K> c) {
+ assert lo <= start && start <= hi;
+ if (start == lo)
+ start++;
+
+ K key0 = s.newKey();
+ K key1 = s.newKey();
+
+ Buffer pivotStore = s.allocate(1);
+ for ( ; start < hi; start++) {
+ s.copyElement(a, start, pivotStore, 0);
+ K pivot = s.getKey(pivotStore, 0, key0);
+
+ // Set left (and right) to the index where a[start] (pivot) belongs
+ int left = lo;
+ int right = start;
+ assert left <= right;
+ /*
+ * Invariants:
+ * pivot >= all in [lo, left).
+ * pivot < all in [right, start).
+ */
+ while (left < right) {
+ int mid = (left + right) >>> 1;
+ if (c.compare(pivot, s.getKey(a, mid, key1)) < 0)
+ right = mid;
+ else
+ left = mid + 1;
+ }
+ assert left == right;
+
+ /*
+ * The invariants still hold: pivot >= all in [lo, left) and
+ * pivot < all in [left, start), so pivot belongs at left. Note
+ * that if there are elements equal to pivot, left points to the
+ * first slot after them -- that's why this sort is stable.
+ * Slide elements over to make room for pivot.
+ */
+ int n = start - left; // The number of elements to move
+ // Switch is just an optimization for arraycopy in default case
+ switch (n) {
+ case 2: s.copyElement(a, left + 1, a, left + 2);
+ case 1: s.copyElement(a, left, a, left + 1);
+ break;
+ default: s.copyRange(a, left, a, left + 1, n);
+ }
+ s.copyElement(pivotStore, 0, a, left);
+ }
+ }
+
+ /**
+ * Returns the length of the run beginning at the specified position in
+ * the specified array and reverses the run if it is descending (ensuring
+ * that the run will always be ascending when the method returns).
+ *
+ * A run is the longest ascending sequence with:
+ *
+ * a[lo] <= a[lo + 1] <= a[lo + 2] <= ...
+ *
+ * or the longest descending sequence with:
+ *
+ * a[lo] > a[lo + 1] > a[lo + 2] > ...
+ *
+ * For its intended use in a stable mergesort, the strictness of the
+ * definition of "descending" is needed so that the call can safely
+ * reverse a descending sequence without violating stability.
+ *
+ * @param a the array in which a run is to be counted and possibly reversed
+ * @param lo index of the first element in the run
+ * @param hi index after the last element that may be contained in the run.
+ It is required that {@code lo < hi}.
+ * @param c the comparator to used for the sort
+ * @return the length of the run beginning at the specified position in
+ * the specified array
+ */
+ private int countRunAndMakeAscending(Buffer a, int lo, int hi, Comparator super K> c) {
+ assert lo < hi;
+ int runHi = lo + 1;
+ if (runHi == hi)
+ return 1;
+
+ K key0 = s.newKey();
+ K key1 = s.newKey();
+
+ // Find end of run, and reverse range if descending
+ if (c.compare(s.getKey(a, runHi++, key0), s.getKey(a, lo, key1)) < 0) { // Descending
+ while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) < 0)
+ runHi++;
+ reverseRange(a, lo, runHi);
+ } else { // Ascending
+ while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0)
+ runHi++;
+ }
+
+ return runHi - lo;
+ }
+
+ /**
+ * Reverse the specified range of the specified array.
+ *
+ * @param a the array in which a range is to be reversed
+ * @param lo the index of the first element in the range to be reversed
+ * @param hi the index after the last element in the range to be reversed
+ */
+ private void reverseRange(Buffer a, int lo, int hi) {
+ hi--;
+ while (lo < hi) {
+ s.swap(a, lo, hi);
+ lo++;
+ hi--;
+ }
+ }
+
+ /**
+ * Returns the minimum acceptable run length for an array of the specified
+ * length. Natural runs shorter than this will be extended with
+ * {@link #binarySort}.
+ *
+ * Roughly speaking, the computation is:
+ *
+ * If n < MIN_MERGE, return n (it's too small to bother with fancy stuff).
+ * Else if n is an exact power of 2, return MIN_MERGE/2.
+ * Else return an int k, MIN_MERGE/2 <= k <= MIN_MERGE, such that n/k
+ * is close to, but strictly less than, an exact power of 2.
+ *
+ * For the rationale, see listsort.txt.
+ *
+ * @param n the length of the array to be sorted
+ * @return the length of the minimum run to be merged
+ */
+ private int minRunLength(int n) {
+ assert n >= 0;
+ int r = 0; // Becomes 1 if any 1 bits are shifted off
+ while (n >= MIN_MERGE) {
+ r |= (n & 1);
+ n >>= 1;
+ }
+ return n + r;
+ }
+
+ private class SortState {
+
+ /**
+ * The Buffer being sorted.
+ */
+ private final Buffer a;
+
+ /**
+ * Length of the sort Buffer.
+ */
+ private final int aLength;
+
+ /**
+ * The comparator for this sort.
+ */
+ private final Comparator super K> c;
+
+ /**
+ * When we get into galloping mode, we stay there until both runs win less
+ * often than MIN_GALLOP consecutive times.
+ */
+ private static final int MIN_GALLOP = 7;
+
+ /**
+ * This controls when we get *into* galloping mode. It is initialized
+ * to MIN_GALLOP. The mergeLo and mergeHi methods nudge it higher for
+ * random data, and lower for highly structured data.
+ */
+ private int minGallop = MIN_GALLOP;
+
+ /**
+ * Maximum initial size of tmp array, which is used for merging. The array
+ * can grow to accommodate demand.
+ *
+ * Unlike Tim's original C version, we do not allocate this much storage
+ * when sorting smaller arrays. This change was required for performance.
+ */
+ private static final int INITIAL_TMP_STORAGE_LENGTH = 256;
+
+ /**
+ * Temp storage for merges.
+ */
+ private Buffer tmp; // Actual runtime type will be Object[], regardless of T
+
+ /**
+ * Length of the temp storage.
+ */
+ private int tmpLength = 0;
+
+ /**
+ * A stack of pending runs yet to be merged. Run i starts at
+ * address base[i] and extends for len[i] elements. It's always
+ * true (so long as the indices are in bounds) that:
+ *
+ * runBase[i] + runLen[i] == runBase[i + 1]
+ *
+ * so we could cut the storage for this, but it's a minor amount,
+ * and keeping all the info explicit simplifies the code.
+ */
+ private int stackSize = 0; // Number of pending runs on stack
+ private final int[] runBase;
+ private final int[] runLen;
+
+ /**
+ * Creates a TimSort instance to maintain the state of an ongoing sort.
+ *
+ * @param a the array to be sorted
+ * @param c the comparator to determine the order of the sort
+ */
+ private SortState(Buffer a, Comparator super K> c, int len) {
+ this.aLength = len;
+ this.a = a;
+ this.c = c;
+
+ // Allocate temp storage (which may be increased later if necessary)
+ tmpLength = len < 2 * INITIAL_TMP_STORAGE_LENGTH ? len >>> 1 : INITIAL_TMP_STORAGE_LENGTH;
+ tmp = s.allocate(tmpLength);
+
+ /*
+ * Allocate runs-to-be-merged stack (which cannot be expanded). The
+ * stack length requirements are described in listsort.txt. The C
+ * version always uses the same stack length (85), but this was
+ * measured to be too expensive when sorting "mid-sized" arrays (e.g.,
+ * 100 elements) in Java. Therefore, we use smaller (but sufficiently
+ * large) stack lengths for smaller arrays. The "magic numbers" in the
+ * computation below must be changed if MIN_MERGE is decreased. See
+ * the MIN_MERGE declaration above for more information.
+ */
+ int stackLen = (len < 120 ? 5 :
+ len < 1542 ? 10 :
+ len < 119151 ? 19 : 40);
+ runBase = new int[stackLen];
+ runLen = new int[stackLen];
+ }
+
+ /**
+ * Pushes the specified run onto the pending-run stack.
+ *
+ * @param runBase index of the first element in the run
+ * @param runLen the number of elements in the run
+ */
+ private void pushRun(int runBase, int runLen) {
+ this.runBase[stackSize] = runBase;
+ this.runLen[stackSize] = runLen;
+ stackSize++;
+ }
+
+ /**
+ * Examines the stack of runs waiting to be merged and merges adjacent runs
+ * until the stack invariants are reestablished:
+ *
+ * 1. runLen[i - 3] > runLen[i - 2] + runLen[i - 1]
+ * 2. runLen[i - 2] > runLen[i - 1]
+ *
+ * This method is called each time a new run is pushed onto the stack,
+ * so the invariants are guaranteed to hold for i < stackSize upon
+ * entry to the method.
+ */
+ private void mergeCollapse() {
+ while (stackSize > 1) {
+ int n = stackSize - 2;
+ if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) {
+ if (runLen[n - 1] < runLen[n + 1])
+ n--;
+ mergeAt(n);
+ } else if (runLen[n] <= runLen[n + 1]) {
+ mergeAt(n);
+ } else {
+ break; // Invariant is established
+ }
+ }
+ }
+
+ /**
+ * Merges all runs on the stack until only one remains. This method is
+ * called once, to complete the sort.
+ */
+ private void mergeForceCollapse() {
+ while (stackSize > 1) {
+ int n = stackSize - 2;
+ if (n > 0 && runLen[n - 1] < runLen[n + 1])
+ n--;
+ mergeAt(n);
+ }
+ }
+
+ /**
+ * Merges the two runs at stack indices i and i+1. Run i must be
+ * the penultimate or antepenultimate run on the stack. In other words,
+ * i must be equal to stackSize-2 or stackSize-3.
+ *
+ * @param i stack index of the first of the two runs to merge
+ */
+ private void mergeAt(int i) {
+ assert stackSize >= 2;
+ assert i >= 0;
+ assert i == stackSize - 2 || i == stackSize - 3;
+
+ int base1 = runBase[i];
+ int len1 = runLen[i];
+ int base2 = runBase[i + 1];
+ int len2 = runLen[i + 1];
+ assert len1 > 0 && len2 > 0;
+ assert base1 + len1 == base2;
+
+ /*
+ * Record the length of the combined runs; if i is the 3rd-last
+ * run now, also slide over the last run (which isn't involved
+ * in this merge). The current run (i+1) goes away in any case.
+ */
+ runLen[i] = len1 + len2;
+ if (i == stackSize - 3) {
+ runBase[i + 1] = runBase[i + 2];
+ runLen[i + 1] = runLen[i + 2];
+ }
+ stackSize--;
+
+ K key0 = s.newKey();
+
+ /*
+ * Find where the first element of run2 goes in run1. Prior elements
+ * in run1 can be ignored (because they're already in place).
+ */
+ int k = gallopRight(s.getKey(a, base2, key0), a, base1, len1, 0, c);
+ assert k >= 0;
+ base1 += k;
+ len1 -= k;
+ if (len1 == 0)
+ return;
+
+ /*
+ * Find where the last element of run1 goes in run2. Subsequent elements
+ * in run2 can be ignored (because they're already in place).
+ */
+ len2 = gallopLeft(s.getKey(a, base1 + len1 - 1, key0), a, base2, len2, len2 - 1, c);
+ assert len2 >= 0;
+ if (len2 == 0)
+ return;
+
+ // Merge remaining runs, using tmp array with min(len1, len2) elements
+ if (len1 <= len2)
+ mergeLo(base1, len1, base2, len2);
+ else
+ mergeHi(base1, len1, base2, len2);
+ }
+
+ /**
+ * Locates the position at which to insert the specified key into the
+ * specified sorted range; if the range contains an element equal to key,
+ * returns the index of the leftmost equal element.
+ *
+ * @param key the key whose insertion point to search for
+ * @param a the array in which to search
+ * @param base the index of the first element in the range
+ * @param len the length of the range; must be > 0
+ * @param hint the index at which to begin the search, 0 <= hint < n.
+ * The closer hint is to the result, the faster this method will run.
+ * @param c the comparator used to order the range, and to search
+ * @return the int k, 0 <= k <= n such that a[b + k - 1] < key <= a[b + k],
+ * pretending that a[b - 1] is minus infinity and a[b + n] is infinity.
+ * In other words, key belongs at index b + k; or in other words,
+ * the first k elements of a should precede key, and the last n - k
+ * should follow it.
+ */
+ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator super K> c) {
+ assert len > 0 && hint >= 0 && hint < len;
+ int lastOfs = 0;
+ int ofs = 1;
+ K key0 = s.newKey();
+
+ if (c.compare(key, s.getKey(a, base + hint, key0)) > 0) {
+ // Gallop right until a[base+hint+lastOfs] < key <= a[base+hint+ofs]
+ int maxOfs = len - hint;
+ while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key0)) > 0) {
+ lastOfs = ofs;
+ ofs = (ofs << 1) + 1;
+ if (ofs <= 0) // int overflow
+ ofs = maxOfs;
+ }
+ if (ofs > maxOfs)
+ ofs = maxOfs;
+
+ // Make offsets relative to base
+ lastOfs += hint;
+ ofs += hint;
+ } else { // key <= a[base + hint]
+ // Gallop left until a[base+hint-ofs] < key <= a[base+hint-lastOfs]
+ final int maxOfs = hint + 1;
+ while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key0)) <= 0) {
+ lastOfs = ofs;
+ ofs = (ofs << 1) + 1;
+ if (ofs <= 0) // int overflow
+ ofs = maxOfs;
+ }
+ if (ofs > maxOfs)
+ ofs = maxOfs;
+
+ // Make offsets relative to base
+ int tmp = lastOfs;
+ lastOfs = hint - ofs;
+ ofs = hint - tmp;
+ }
+ assert -1 <= lastOfs && lastOfs < ofs && ofs <= len;
+
+ /*
+ * Now a[base+lastOfs] < key <= a[base+ofs], so key belongs somewhere
+ * to the right of lastOfs but no farther right than ofs. Do a binary
+ * search, with invariant a[base + lastOfs - 1] < key <= a[base + ofs].
+ */
+ lastOfs++;
+ while (lastOfs < ofs) {
+ int m = lastOfs + ((ofs - lastOfs) >>> 1);
+
+ if (c.compare(key, s.getKey(a, base + m, key0)) > 0)
+ lastOfs = m + 1; // a[base + m] < key
+ else
+ ofs = m; // key <= a[base + m]
+ }
+ assert lastOfs == ofs; // so a[base + ofs - 1] < key <= a[base + ofs]
+ return ofs;
+ }
+
+ /**
+ * Like gallopLeft, except that if the range contains an element equal to
+ * key, gallopRight returns the index after the rightmost equal element.
+ *
+ * @param key the key whose insertion point to search for
+ * @param a the array in which to search
+ * @param base the index of the first element in the range
+ * @param len the length of the range; must be > 0
+ * @param hint the index at which to begin the search, 0 <= hint < n.
+ * The closer hint is to the result, the faster this method will run.
+ * @param c the comparator used to order the range, and to search
+ * @return the int k, 0 <= k <= n such that a[b + k - 1] <= key < a[b + k]
+ */
+ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator super K> c) {
+ assert len > 0 && hint >= 0 && hint < len;
+
+ int ofs = 1;
+ int lastOfs = 0;
+ K key1 = s.newKey();
+
+ if (c.compare(key, s.getKey(a, base + hint, key1)) < 0) {
+ // Gallop left until a[b+hint - ofs] <= key < a[b+hint - lastOfs]
+ int maxOfs = hint + 1;
+ while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key1)) < 0) {
+ lastOfs = ofs;
+ ofs = (ofs << 1) + 1;
+ if (ofs <= 0) // int overflow
+ ofs = maxOfs;
+ }
+ if (ofs > maxOfs)
+ ofs = maxOfs;
+
+ // Make offsets relative to b
+ int tmp = lastOfs;
+ lastOfs = hint - ofs;
+ ofs = hint - tmp;
+ } else { // a[b + hint] <= key
+ // Gallop right until a[b+hint + lastOfs] <= key < a[b+hint + ofs]
+ int maxOfs = len - hint;
+ while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key1)) >= 0) {
+ lastOfs = ofs;
+ ofs = (ofs << 1) + 1;
+ if (ofs <= 0) // int overflow
+ ofs = maxOfs;
+ }
+ if (ofs > maxOfs)
+ ofs = maxOfs;
+
+ // Make offsets relative to b
+ lastOfs += hint;
+ ofs += hint;
+ }
+ assert -1 <= lastOfs && lastOfs < ofs && ofs <= len;
+
+ /*
+ * Now a[b + lastOfs] <= key < a[b + ofs], so key belongs somewhere to
+ * the right of lastOfs but no farther right than ofs. Do a binary
+ * search, with invariant a[b + lastOfs - 1] <= key < a[b + ofs].
+ */
+ lastOfs++;
+ while (lastOfs < ofs) {
+ int m = lastOfs + ((ofs - lastOfs) >>> 1);
+
+ if (c.compare(key, s.getKey(a, base + m, key1)) < 0)
+ ofs = m; // key < a[b + m]
+ else
+ lastOfs = m + 1; // a[b + m] <= key
+ }
+ assert lastOfs == ofs; // so a[b + ofs - 1] <= key < a[b + ofs]
+ return ofs;
+ }
+
+ /**
+ * Merges two adjacent runs in place, in a stable fashion. The first
+ * element of the first run must be greater than the first element of the
+ * second run (a[base1] > a[base2]), and the last element of the first run
+ * (a[base1 + len1-1]) must be greater than all elements of the second run.
+ *
+ * For performance, this method should be called only when len1 <= len2;
+ * its twin, mergeHi should be called if len1 >= len2. (Either method
+ * may be called if len1 == len2.)
+ *
+ * @param base1 index of first element in first run to be merged
+ * @param len1 length of first run to be merged (must be > 0)
+ * @param base2 index of first element in second run to be merged
+ * (must be aBase + aLen)
+ * @param len2 length of second run to be merged (must be > 0)
+ */
+ private void mergeLo(int base1, int len1, int base2, int len2) {
+ assert len1 > 0 && len2 > 0 && base1 + len1 == base2;
+
+ // Copy first run into temp array
+ Buffer a = this.a; // For performance
+ Buffer tmp = ensureCapacity(len1);
+ s.copyRange(a, base1, tmp, 0, len1);
+
+ int cursor1 = 0; // Indexes into tmp array
+ int cursor2 = base2; // Indexes int a
+ int dest = base1; // Indexes int a
+
+ // Move first element of second run and deal with degenerate cases
+ s.copyElement(a, cursor2++, a, dest++);
+ if (--len2 == 0) {
+ s.copyRange(tmp, cursor1, a, dest, len1);
+ return;
+ }
+ if (len1 == 1) {
+ s.copyRange(a, cursor2, a, dest, len2);
+ s.copyElement(tmp, cursor1, a, dest + len2); // Last elt of run 1 to end of merge
+ return;
+ }
+
+ K key0 = s.newKey();
+ K key1 = s.newKey();
+
+ Comparator super K> c = this.c; // Use local variable for performance
+ int minGallop = this.minGallop; // " " " " "
+ outer:
+ while (true) {
+ int count1 = 0; // Number of times in a row that first run won
+ int count2 = 0; // Number of times in a row that second run won
+
+ /*
+ * Do the straightforward thing until (if ever) one run starts
+ * winning consistently.
+ */
+ do {
+ assert len1 > 1 && len2 > 0;
+ if (c.compare(s.getKey(a, cursor2, key0), s.getKey(tmp, cursor1, key1)) < 0) {
+ s.copyElement(a, cursor2++, a, dest++);
+ count2++;
+ count1 = 0;
+ if (--len2 == 0)
+ break outer;
+ } else {
+ s.copyElement(tmp, cursor1++, a, dest++);
+ count1++;
+ count2 = 0;
+ if (--len1 == 1)
+ break outer;
+ }
+ } while ((count1 | count2) < minGallop);
+
+ /*
+ * One run is winning so consistently that galloping may be a
+ * huge win. So try that, and continue galloping until (if ever)
+ * neither run appears to be winning consistently anymore.
+ */
+ do {
+ assert len1 > 1 && len2 > 0;
+ count1 = gallopRight(s.getKey(a, cursor2, key0), tmp, cursor1, len1, 0, c);
+ if (count1 != 0) {
+ s.copyRange(tmp, cursor1, a, dest, count1);
+ dest += count1;
+ cursor1 += count1;
+ len1 -= count1;
+ if (len1 <= 1) // len1 == 1 || len1 == 0
+ break outer;
+ }
+ s.copyElement(a, cursor2++, a, dest++);
+ if (--len2 == 0)
+ break outer;
+
+ count2 = gallopLeft(s.getKey(tmp, cursor1, key0), a, cursor2, len2, 0, c);
+ if (count2 != 0) {
+ s.copyRange(a, cursor2, a, dest, count2);
+ dest += count2;
+ cursor2 += count2;
+ len2 -= count2;
+ if (len2 == 0)
+ break outer;
+ }
+ s.copyElement(tmp, cursor1++, a, dest++);
+ if (--len1 == 1)
+ break outer;
+ minGallop--;
+ } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP);
+ if (minGallop < 0)
+ minGallop = 0;
+ minGallop += 2; // Penalize for leaving gallop mode
+ } // End of "outer" loop
+ this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field
+
+ if (len1 == 1) {
+ assert len2 > 0;
+ s.copyRange(a, cursor2, a, dest, len2);
+ s.copyElement(tmp, cursor1, a, dest + len2); // Last elt of run 1 to end of merge
+ } else if (len1 == 0) {
+ throw new IllegalArgumentException(
+ "Comparison method violates its general contract!");
+ } else {
+ assert len2 == 0;
+ assert len1 > 1;
+ s.copyRange(tmp, cursor1, a, dest, len1);
+ }
+ }
+
+ /**
+ * Like mergeLo, except that this method should be called only if
+ * len1 >= len2; mergeLo should be called if len1 <= len2. (Either method
+ * may be called if len1 == len2.)
+ *
+ * @param base1 index of first element in first run to be merged
+ * @param len1 length of first run to be merged (must be > 0)
+ * @param base2 index of first element in second run to be merged
+ * (must be aBase + aLen)
+ * @param len2 length of second run to be merged (must be > 0)
+ */
+ private void mergeHi(int base1, int len1, int base2, int len2) {
+ assert len1 > 0 && len2 > 0 && base1 + len1 == base2;
+
+ // Copy second run into temp array
+ Buffer a = this.a; // For performance
+ Buffer tmp = ensureCapacity(len2);
+ s.copyRange(a, base2, tmp, 0, len2);
+
+ int cursor1 = base1 + len1 - 1; // Indexes into a
+ int cursor2 = len2 - 1; // Indexes into tmp array
+ int dest = base2 + len2 - 1; // Indexes into a
+
+ K key0 = s.newKey();
+ K key1 = s.newKey();
+
+ // Move last element of first run and deal with degenerate cases
+ s.copyElement(a, cursor1--, a, dest--);
+ if (--len1 == 0) {
+ s.copyRange(tmp, 0, a, dest - (len2 - 1), len2);
+ return;
+ }
+ if (len2 == 1) {
+ dest -= len1;
+ cursor1 -= len1;
+ s.copyRange(a, cursor1 + 1, a, dest + 1, len1);
+ s.copyElement(tmp, cursor2, a, dest);
+ return;
+ }
+
+ Comparator super K> c = this.c; // Use local variable for performance
+ int minGallop = this.minGallop; // " " " " "
+ outer:
+ while (true) {
+ int count1 = 0; // Number of times in a row that first run won
+ int count2 = 0; // Number of times in a row that second run won
+
+ /*
+ * Do the straightforward thing until (if ever) one run
+ * appears to win consistently.
+ */
+ do {
+ assert len1 > 0 && len2 > 1;
+ if (c.compare(s.getKey(tmp, cursor2, key0), s.getKey(a, cursor1, key1)) < 0) {
+ s.copyElement(a, cursor1--, a, dest--);
+ count1++;
+ count2 = 0;
+ if (--len1 == 0)
+ break outer;
+ } else {
+ s.copyElement(tmp, cursor2--, a, dest--);
+ count2++;
+ count1 = 0;
+ if (--len2 == 1)
+ break outer;
+ }
+ } while ((count1 | count2) < minGallop);
+
+ /*
+ * One run is winning so consistently that galloping may be a
+ * huge win. So try that, and continue galloping until (if ever)
+ * neither run appears to be winning consistently anymore.
+ */
+ do {
+ assert len1 > 0 && len2 > 1;
+ count1 = len1 - gallopRight(s.getKey(tmp, cursor2, key0), a, base1, len1, len1 - 1, c);
+ if (count1 != 0) {
+ dest -= count1;
+ cursor1 -= count1;
+ len1 -= count1;
+ s.copyRange(a, cursor1 + 1, a, dest + 1, count1);
+ if (len1 == 0)
+ break outer;
+ }
+ s.copyElement(tmp, cursor2--, a, dest--);
+ if (--len2 == 1)
+ break outer;
+
+ count2 = len2 - gallopLeft(s.getKey(a, cursor1, key0), tmp, 0, len2, len2 - 1, c);
+ if (count2 != 0) {
+ dest -= count2;
+ cursor2 -= count2;
+ len2 -= count2;
+ s.copyRange(tmp, cursor2 + 1, a, dest + 1, count2);
+ if (len2 <= 1) // len2 == 1 || len2 == 0
+ break outer;
+ }
+ s.copyElement(a, cursor1--, a, dest--);
+ if (--len1 == 0)
+ break outer;
+ minGallop--;
+ } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP);
+ if (minGallop < 0)
+ minGallop = 0;
+ minGallop += 2; // Penalize for leaving gallop mode
+ } // End of "outer" loop
+ this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field
+
+ if (len2 == 1) {
+ assert len1 > 0;
+ dest -= len1;
+ cursor1 -= len1;
+ s.copyRange(a, cursor1 + 1, a, dest + 1, len1);
+ s.copyElement(tmp, cursor2, a, dest); // Move first elt of run2 to front of merge
+ } else if (len2 == 0) {
+ throw new IllegalArgumentException(
+ "Comparison method violates its general contract!");
+ } else {
+ assert len1 == 0;
+ assert len2 > 0;
+ s.copyRange(tmp, 0, a, dest - (len2 - 1), len2);
+ }
+ }
+
+ /**
+ * Ensures that the external array tmp has at least the specified
+ * number of elements, increasing its size if necessary. The size
+ * increases exponentially to ensure amortized linear time complexity.
+ *
+ * @param minCapacity the minimum required capacity of the tmp array
+ * @return tmp, whether or not it grew
+ */
+ private Buffer ensureCapacity(int minCapacity) {
+ if (tmpLength < minCapacity) {
+ // Compute smallest power of 2 > minCapacity
+ int newSize = minCapacity;
+ newSize |= newSize >> 1;
+ newSize |= newSize >> 2;
+ newSize |= newSize >> 4;
+ newSize |= newSize >> 8;
+ newSize |= newSize >> 16;
+ newSize++;
+
+ if (newSize < 0) // Not bloody likely!
+ newSize = minCapacity;
+ else
+ newSize = Math.min(newSize, aLength >>> 1);
+
+ tmp = s.allocate(newSize);
+ tmpLength = newSize;
+ }
+ return tmp;
+ }
+ }
+}
diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
new file mode 100644
index 0000000000000..14ba37d7c9bd9
--- /dev/null
+++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Register functions to show/hide columns based on checkboxes. These need
+ * to be registered after the page loads. */
+$(function() {
+ $("span.expand-additional-metrics").click(function(){
+ // Expand the list of additional metrics.
+ var additionalMetricsDiv = $(this).parent().find('.additional-metrics');
+ $(additionalMetricsDiv).toggleClass('collapsed');
+
+ // Switch the class of the arrow from open to closed.
+ $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-open');
+ $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-closed');
+ });
+
+ stripeSummaryTable();
+
+ $("input:checkbox").click(function() {
+ var column = "table ." + $(this).attr("name");
+ $(column).toggle();
+ stripeSummaryTable();
+ });
+
+ $("#select-all-metrics").click(function() {
+ if (this.checked) {
+ // Toggle all un-checked options.
+ $('input:checkbox:not(:checked)').trigger('click');
+ } else {
+ // Toggle all checked options.
+ $('input:checkbox:checked').trigger('click');
+ }
+ });
+
+ // Trigger a click on the checkbox if a user clicks the label next to it.
+ $("span.additional-metric-title").click(function() {
+ $(this).parent().find('input:checkbox').trigger('click');
+ });
+});
diff --git a/core/src/main/resources/org/apache/spark/ui/static/table.js b/core/src/main/resources/org/apache/spark/ui/static/table.js
new file mode 100644
index 0000000000000..656147e40d13e
--- /dev/null
+++ b/core/src/main/resources/org/apache/spark/ui/static/table.js
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Adds background colors to stripe table rows in the summary table (on the stage page). This is
+ * necessary (instead of using css or the table striping provided by bootstrap) because the summary
+ * table has hidden rows.
+ *
+ * An ID selector (rather than a class selector) is used to ensure this runs quickly even on pages
+ * with thousands of task rows (ID selectors are much faster than class selectors). */
+function stripeSummaryTable() {
+ $("#task-summary-table").find("tr:not(:hidden)").each(function (index) {
+ if (index % 2 == 1) {
+ $(this).css("background-color", "#f9f9f9");
+ } else {
+ $(this).css("background-color", "#ffffff");
+ }
+ });
+}
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index 445110d63e184..cdf85bfbf326f 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -51,6 +51,11 @@ table.sortable thead {
cursor: pointer;
}
+table.sortable td {
+ word-wrap: break-word;
+ max-width: 600px;
+}
+
.progress {
margin-bottom: 0px; position: relative
}
@@ -115,7 +120,57 @@ pre {
border: none;
}
+.stacktrace-details {
+ max-height: 300px;
+ overflow-y: auto;
+ margin: 0;
+ transition: max-height 0.5s ease-out, padding 0.5s ease-out;
+}
+
+.stacktrace-details.collapsed {
+ max-height: 0;
+ padding-top: 0;
+ padding-bottom: 0;
+ border: none;
+}
+
+span.expand-additional-metrics {
+ cursor: pointer;
+}
+
+span.additional-metric-title {
+ cursor: pointer;
+}
+
+.additional-metrics.collapsed {
+ display: none;
+}
+
.tooltip {
font-weight: normal;
}
+.arrow-open {
+ width: 0;
+ height: 0;
+ border-left: 5px solid transparent;
+ border-right: 5px solid transparent;
+ border-top: 5px solid black;
+ float: left;
+ margin-top: 6px;
+}
+
+.arrow-closed {
+ width: 0;
+ height: 0;
+ border-top: 5px solid transparent;
+ border-bottom: 5px solid transparent;
+ border-left: 5px solid black;
+ display: inline-block;
+}
+
+/* Hide all additional metrics by default. This is done here rather than using JavaScript to
+ * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */
+.scheduler_delay, .gc_time, .deserialization_time, .serialization_time, .getting_result_time {
+ display: none;
+}
diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index 12f2fe031cb1d..000bbd6b532ad 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -18,12 +18,14 @@
package org.apache.spark
import java.io.{ObjectInputStream, Serializable}
+import java.util.concurrent.atomic.AtomicLong
import scala.collection.generic.Growable
import scala.collection.mutable.Map
import scala.reflect.ClassTag
import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.util.Utils
/**
* A data type that can be accumulated, ie has an commutative and associative "add" operation,
@@ -126,7 +128,7 @@ class Accumulable[R, T] (
}
// Called by Java when deserializing an object
- private def readObject(in: ObjectInputStream) {
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
in.defaultReadObject()
value_ = zero
deserialized = true
@@ -227,6 +229,7 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa
*/
class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String])
extends Accumulable[T,T](initialValue, param, name) {
+
def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None)
}
@@ -243,6 +246,36 @@ trait AccumulatorParam[T] extends AccumulableParam[T, T] {
}
}
+object AccumulatorParam {
+
+ // The following implicit objects were in SparkContext before 1.2 and users had to
+ // `import SparkContext._` to enable them. Now we move them here to make the compiler find
+ // them automatically. However, as there are duplicate codes in SparkContext for backward
+ // compatibility, please update them accordingly if you modify the following implicit objects.
+
+ implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
+ def addInPlace(t1: Double, t2: Double): Double = t1 + t2
+ def zero(initialValue: Double) = 0.0
+ }
+
+ implicit object IntAccumulatorParam extends AccumulatorParam[Int] {
+ def addInPlace(t1: Int, t2: Int): Int = t1 + t2
+ def zero(initialValue: Int) = 0
+ }
+
+ implicit object LongAccumulatorParam extends AccumulatorParam[Long] {
+ def addInPlace(t1: Long, t2: Long) = t1 + t2
+ def zero(initialValue: Long) = 0L
+ }
+
+ implicit object FloatAccumulatorParam extends AccumulatorParam[Float] {
+ def addInPlace(t1: Float, t2: Float) = t1 + t2
+ def zero(initialValue: Float) = 0f
+ }
+
+ // TODO: Add AccumulatorParams for other types, e.g. lists and strings
+}
+
// TODO: The multi-thread support in accumulators is kind of lame; check
// if there's a more intuitive way of doing it right
private object Accumulators {
@@ -251,7 +284,7 @@ private object Accumulators {
val localAccums = Map[Thread, Map[Long, Accumulable[_, _]]]()
var lastId: Long = 0
- def newId: Long = synchronized {
+ def newId(): Long = synchronized {
lastId += 1
lastId
}
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index f8584b90cabe6..80da62c44edc5 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -61,7 +61,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
val computedValues = rdd.computeOrReadCheckpoint(partition, context)
// If the task is running locally, do not persist the result
- if (context.runningLocally) {
+ if (context.isRunningLocally) {
return computedValues
}
@@ -168,8 +168,6 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
arr.iterator.asInstanceOf[Iterator[T]]
case Right(it) =>
// There is not enough space to cache this partition in memory
- logWarning(s"Not enough space to cache partition $key in memory! " +
- s"Free memory is ${blockManager.memoryStore.freeMemory} bytes.")
val returnValues = it.asInstanceOf[Iterator[T]]
if (putLevel.useDisk) {
logWarning(s"Persisting partition $key to disk instead.")
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
new file mode 100644
index 0000000000000..88adb892998af
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -0,0 +1,516 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.collection.mutable
+
+import org.apache.spark.scheduler._
+
+/**
+ * An agent that dynamically allocates and removes executors based on the workload.
+ *
+ * The add policy depends on whether there are backlogged tasks waiting to be scheduled. If
+ * the scheduler queue is not drained in N seconds, then new executors are added. If the queue
+ * persists for another M seconds, then more executors are added and so on. The number added
+ * in each round increases exponentially from the previous round until an upper bound on the
+ * number of executors has been reached. The upper bound is based both on a configured property
+ * and on the number of tasks pending: the policy will never increase the number of executor
+ * requests past the number needed to handle all pending tasks.
+ *
+ * The rationale for the exponential increase is twofold: (1) Executors should be added slowly
+ * in the beginning in case the number of extra executors needed turns out to be small. Otherwise,
+ * we may add more executors than we need just to remove them later. (2) Executors should be added
+ * quickly over time in case the maximum number of executors is very high. Otherwise, it will take
+ * a long time to ramp up under heavy workloads.
+ *
+ * The remove policy is simpler: If an executor has been idle for K seconds, meaning it has not
+ * been scheduled to run any tasks, then it is removed.
+ *
+ * There is no retry logic in either case because we make the assumption that the cluster manager
+ * will eventually fulfill all requests it receives asynchronously.
+ *
+ * The relevant Spark properties include the following:
+ *
+ * spark.dynamicAllocation.enabled - Whether this feature is enabled
+ * spark.dynamicAllocation.minExecutors - Lower bound on the number of executors
+ * spark.dynamicAllocation.maxExecutors - Upper bound on the number of executors
+ *
+ * spark.dynamicAllocation.schedulerBacklogTimeout (M) -
+ * If there are backlogged tasks for this duration, add new executors
+ *
+ * spark.dynamicAllocation.sustainedSchedulerBacklogTimeout (N) -
+ * If the backlog is sustained for this duration, add more executors
+ * This is used only after the initial backlog timeout is exceeded
+ *
+ * spark.dynamicAllocation.executorIdleTimeout (K) -
+ * If an executor has been idle for this duration, remove it
+ */
+private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging {
+ import ExecutorAllocationManager._
+
+ private val conf = sc.conf
+
+ // Lower and upper bounds on the number of executors. These are required.
+ private val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", -1)
+ private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", -1)
+
+ // How long there must be backlogged tasks for before an addition is triggered
+ private val schedulerBacklogTimeout = conf.getLong(
+ "spark.dynamicAllocation.schedulerBacklogTimeout", 60)
+
+ // Same as above, but used only after `schedulerBacklogTimeout` is exceeded
+ private val sustainedSchedulerBacklogTimeout = conf.getLong(
+ "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", schedulerBacklogTimeout)
+
+ // How long an executor must be idle for before it is removed
+ private val executorIdleTimeout = conf.getLong(
+ "spark.dynamicAllocation.executorIdleTimeout", 600)
+
+ // During testing, the methods to actually kill and add executors are mocked out
+ private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false)
+
+ // TODO: The default value of 1 for spark.executor.cores works right now because dynamic
+ // allocation is only supported for YARN and the default number of cores per executor in YARN is
+ // 1, but it might need to be attained differently for different cluster managers
+ private val tasksPerExecutor =
+ conf.getInt("spark.executor.cores", 1) / conf.getInt("spark.task.cpus", 1)
+
+ validateSettings()
+
+ // Number of executors to add in the next round
+ private var numExecutorsToAdd = 1
+
+ // Number of executors that have been requested but have not registered yet
+ private var numExecutorsPending = 0
+
+ // Executors that have been requested to be removed but have not been killed yet
+ private val executorsPendingToRemove = new mutable.HashSet[String]
+
+ // All known executors
+ private val executorIds = new mutable.HashSet[String]
+
+ // A timestamp of when an addition should be triggered, or NOT_SET if it is not set
+ // This is set when pending tasks are added but not scheduled yet
+ private var addTime: Long = NOT_SET
+
+ // A timestamp for each executor of when the executor should be removed, indexed by the ID
+ // This is set when an executor is no longer running a task, or when it first registers
+ private val removeTimes = new mutable.HashMap[String, Long]
+
+ // Polling loop interval (ms)
+ private val intervalMillis: Long = 100
+
+ // Clock used to schedule when executors should be added and removed
+ private var clock: Clock = new RealClock
+
+ // Listener for Spark events that impact the allocation policy
+ private val listener = new ExecutorAllocationListener(this)
+
+ /**
+ * Verify that the settings specified through the config are valid.
+ * If not, throw an appropriate exception.
+ */
+ private def validateSettings(): Unit = {
+ if (minNumExecutors < 0 || maxNumExecutors < 0) {
+ throw new SparkException("spark.dynamicAllocation.{min/max}Executors must be set!")
+ }
+ if (minNumExecutors == 0 || maxNumExecutors == 0) {
+ throw new SparkException("spark.dynamicAllocation.{min/max}Executors cannot be 0!")
+ }
+ if (minNumExecutors > maxNumExecutors) {
+ throw new SparkException(s"spark.dynamicAllocation.minExecutors ($minNumExecutors) must " +
+ s"be less than or equal to spark.dynamicAllocation.maxExecutors ($maxNumExecutors)!")
+ }
+ if (schedulerBacklogTimeout <= 0) {
+ throw new SparkException("spark.dynamicAllocation.schedulerBacklogTimeout must be > 0!")
+ }
+ if (sustainedSchedulerBacklogTimeout <= 0) {
+ throw new SparkException(
+ "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout must be > 0!")
+ }
+ if (executorIdleTimeout <= 0) {
+ throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be > 0!")
+ }
+ // Require external shuffle service for dynamic allocation
+ // Otherwise, we may lose shuffle files when killing executors
+ if (!conf.getBoolean("spark.shuffle.service.enabled", false) && !testing) {
+ throw new SparkException("Dynamic allocation of executors requires the external " +
+ "shuffle service. You may enable this through spark.shuffle.service.enabled.")
+ }
+ if (tasksPerExecutor == 0) {
+ throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.cores")
+ }
+ }
+
+ /**
+ * Use a different clock for this allocation manager. This is mainly used for testing.
+ */
+ def setClock(newClock: Clock): Unit = {
+ clock = newClock
+ }
+
+ /**
+ * Register for scheduler callbacks to decide when to add and remove executors.
+ */
+ def start(): Unit = {
+ sc.addSparkListener(listener)
+ startPolling()
+ }
+
+ /**
+ * Start the main polling thread that keeps track of when to add and remove executors.
+ */
+ private def startPolling(): Unit = {
+ val t = new Thread {
+ override def run(): Unit = {
+ while (true) {
+ try {
+ schedule()
+ } catch {
+ case e: Exception => logError("Exception in dynamic executor allocation thread!", e)
+ }
+ Thread.sleep(intervalMillis)
+ }
+ }
+ }
+ t.setName("spark-dynamic-executor-allocation")
+ t.setDaemon(true)
+ t.start()
+ }
+
+ /**
+ * If the add time has expired, request new executors and refresh the add time.
+ * If the remove time for an existing executor has expired, kill the executor.
+ * This is factored out into its own method for testing.
+ */
+ private def schedule(): Unit = synchronized {
+ val now = clock.getTimeMillis
+ if (addTime != NOT_SET && now >= addTime) {
+ addExecutors()
+ logDebug(s"Starting timer to add more executors (to " +
+ s"expire in $sustainedSchedulerBacklogTimeout seconds)")
+ addTime += sustainedSchedulerBacklogTimeout * 1000
+ }
+
+ removeTimes.foreach { case (executorId, expireTime) =>
+ if (now >= expireTime) {
+ removeExecutor(executorId)
+ removeTimes.remove(executorId)
+ }
+ }
+ }
+
+ /**
+ * Request a number of executors from the cluster manager.
+ * If the cap on the number of executors is reached, give up and reset the
+ * number of executors to add next round instead of continuing to double it.
+ * Return the number actually requested.
+ */
+ private def addExecutors(): Int = synchronized {
+ // Do not request more executors if we have already reached the upper bound
+ val numExistingExecutors = executorIds.size + numExecutorsPending
+ if (numExistingExecutors >= maxNumExecutors) {
+ logDebug(s"Not adding executors because there are already ${executorIds.size} " +
+ s"registered and $numExecutorsPending pending executor(s) (limit $maxNumExecutors)")
+ numExecutorsToAdd = 1
+ return 0
+ }
+
+ // The number of executors needed to satisfy all pending tasks is the number of tasks pending
+ // divided by the number of tasks each executor can fit, rounded up.
+ val maxNumExecutorsPending =
+ (listener.totalPendingTasks() + tasksPerExecutor - 1) / tasksPerExecutor
+ if (numExecutorsPending >= maxNumExecutorsPending) {
+ logDebug(s"Not adding executors because there are already $numExecutorsPending " +
+ s"pending and pending tasks could only fill $maxNumExecutorsPending")
+ numExecutorsToAdd = 1
+ return 0
+ }
+
+ // It's never useful to request more executors than could satisfy all the pending tasks, so
+ // cap request at that amount.
+ // Also cap request with respect to the configured upper bound.
+ val maxNumExecutorsToAdd = math.min(
+ maxNumExecutorsPending - numExecutorsPending,
+ maxNumExecutors - numExistingExecutors)
+ assert(maxNumExecutorsToAdd > 0)
+
+ val actualNumExecutorsToAdd = math.min(numExecutorsToAdd, maxNumExecutorsToAdd)
+
+ val newTotalExecutors = numExistingExecutors + actualNumExecutorsToAdd
+ val addRequestAcknowledged = testing || sc.requestExecutors(actualNumExecutorsToAdd)
+ if (addRequestAcknowledged) {
+ logInfo(s"Requesting $actualNumExecutorsToAdd new executor(s) because " +
+ s"tasks are backlogged (new desired total will be $newTotalExecutors)")
+ numExecutorsToAdd =
+ if (actualNumExecutorsToAdd == numExecutorsToAdd) numExecutorsToAdd * 2 else 1
+ numExecutorsPending += actualNumExecutorsToAdd
+ actualNumExecutorsToAdd
+ } else {
+ logWarning(s"Unable to reach the cluster manager " +
+ s"to request $actualNumExecutorsToAdd executors!")
+ 0
+ }
+ }
+
+ /**
+ * Request the cluster manager to remove the given executor.
+ * Return whether the request is received.
+ */
+ private def removeExecutor(executorId: String): Boolean = synchronized {
+ // Do not kill the executor if we are not aware of it (should never happen)
+ if (!executorIds.contains(executorId)) {
+ logWarning(s"Attempted to remove unknown executor $executorId!")
+ return false
+ }
+
+ // Do not kill the executor again if it is already pending to be killed (should never happen)
+ if (executorsPendingToRemove.contains(executorId)) {
+ logWarning(s"Attempted to remove executor $executorId " +
+ s"when it is already pending to be removed!")
+ return false
+ }
+
+ // Do not kill the executor if we have already reached the lower bound
+ val numExistingExecutors = executorIds.size - executorsPendingToRemove.size
+ if (numExistingExecutors - 1 < minNumExecutors) {
+ logInfo(s"Not removing idle executor $executorId because there are only " +
+ s"$numExistingExecutors executor(s) left (limit $minNumExecutors)")
+ return false
+ }
+
+ // Send a request to the backend to kill this executor
+ val removeRequestAcknowledged = testing || sc.killExecutor(executorId)
+ if (removeRequestAcknowledged) {
+ logInfo(s"Removing executor $executorId because it has been idle for " +
+ s"$executorIdleTimeout seconds (new desired total will be ${numExistingExecutors - 1})")
+ executorsPendingToRemove.add(executorId)
+ true
+ } else {
+ logWarning(s"Unable to reach the cluster manager to kill executor $executorId!")
+ false
+ }
+ }
+
+ /**
+ * Callback invoked when the specified executor has been added.
+ */
+ private def onExecutorAdded(executorId: String): Unit = synchronized {
+ if (!executorIds.contains(executorId)) {
+ executorIds.add(executorId)
+ executorIds.foreach(onExecutorIdle)
+ logInfo(s"New executor $executorId has registered (new total is ${executorIds.size})")
+ if (numExecutorsPending > 0) {
+ numExecutorsPending -= 1
+ logDebug(s"Decremented number of pending executors ($numExecutorsPending left)")
+ }
+ } else {
+ logWarning(s"Duplicate executor $executorId has registered")
+ }
+ }
+
+ /**
+ * Callback invoked when the specified executor has been removed.
+ */
+ private def onExecutorRemoved(executorId: String): Unit = synchronized {
+ if (executorIds.contains(executorId)) {
+ executorIds.remove(executorId)
+ removeTimes.remove(executorId)
+ logInfo(s"Existing executor $executorId has been removed (new total is ${executorIds.size})")
+ if (executorsPendingToRemove.contains(executorId)) {
+ executorsPendingToRemove.remove(executorId)
+ logDebug(s"Executor $executorId is no longer pending to " +
+ s"be removed (${executorsPendingToRemove.size} left)")
+ }
+ } else {
+ logWarning(s"Unknown executor $executorId has been removed!")
+ }
+ }
+
+ /**
+ * Callback invoked when the scheduler receives new pending tasks.
+ * This sets a time in the future that decides when executors should be added
+ * if it is not already set.
+ */
+ private def onSchedulerBacklogged(): Unit = synchronized {
+ if (addTime == NOT_SET) {
+ logDebug(s"Starting timer to add executors because pending tasks " +
+ s"are building up (to expire in $schedulerBacklogTimeout seconds)")
+ addTime = clock.getTimeMillis + schedulerBacklogTimeout * 1000
+ }
+ }
+
+ /**
+ * Callback invoked when the scheduler queue is drained.
+ * This resets all variables used for adding executors.
+ */
+ private def onSchedulerQueueEmpty(): Unit = synchronized {
+ logDebug(s"Clearing timer to add executors because there are no more pending tasks")
+ addTime = NOT_SET
+ numExecutorsToAdd = 1
+ }
+
+ /**
+ * Callback invoked when the specified executor is no longer running any tasks.
+ * This sets a time in the future that decides when this executor should be removed if
+ * the executor is not already marked as idle.
+ */
+ private def onExecutorIdle(executorId: String): Unit = synchronized {
+ if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) {
+ logDebug(s"Starting idle timer for $executorId because there are no more tasks " +
+ s"scheduled to run on the executor (to expire in $executorIdleTimeout seconds)")
+ removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeout * 1000
+ }
+ }
+
+ /**
+ * Callback invoked when the specified executor is now running a task.
+ * This resets all variables used for removing this executor.
+ */
+ private def onExecutorBusy(executorId: String): Unit = synchronized {
+ logDebug(s"Clearing idle timer for $executorId because it is now running a task")
+ removeTimes.remove(executorId)
+ }
+
+ /**
+ * A listener that notifies the given allocation manager of when to add and remove executors.
+ *
+ * This class is intentionally conservative in its assumptions about the relative ordering
+ * and consistency of events returned by the listener. For simplicity, it does not account
+ * for speculated tasks.
+ */
+ private class ExecutorAllocationListener(allocationManager: ExecutorAllocationManager)
+ extends SparkListener {
+
+ private val stageIdToNumTasks = new mutable.HashMap[Int, Int]
+ private val stageIdToTaskIndices = new mutable.HashMap[Int, mutable.HashSet[Int]]
+ private val executorIdToTaskIds = new mutable.HashMap[String, mutable.HashSet[Long]]
+
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
+ synchronized {
+ val stageId = stageSubmitted.stageInfo.stageId
+ val numTasks = stageSubmitted.stageInfo.numTasks
+ stageIdToNumTasks(stageId) = numTasks
+ allocationManager.onSchedulerBacklogged()
+ }
+ }
+
+ override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
+ synchronized {
+ val stageId = stageCompleted.stageInfo.stageId
+ stageIdToNumTasks -= stageId
+ stageIdToTaskIndices -= stageId
+
+ // If this is the last stage with pending tasks, mark the scheduler queue as empty
+ // This is needed in case the stage is aborted for any reason
+ if (stageIdToNumTasks.isEmpty) {
+ allocationManager.onSchedulerQueueEmpty()
+ }
+ }
+ }
+
+ override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized {
+ val stageId = taskStart.stageId
+ val taskId = taskStart.taskInfo.taskId
+ val taskIndex = taskStart.taskInfo.index
+ val executorId = taskStart.taskInfo.executorId
+
+ // If this is the last pending task, mark the scheduler queue as empty
+ stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex
+ val numTasksScheduled = stageIdToTaskIndices(stageId).size
+ val numTasksTotal = stageIdToNumTasks.getOrElse(stageId, -1)
+ if (numTasksScheduled == numTasksTotal) {
+ // No more pending tasks for this stage
+ stageIdToNumTasks -= stageId
+ if (stageIdToNumTasks.isEmpty) {
+ allocationManager.onSchedulerQueueEmpty()
+ }
+ }
+
+ // Mark the executor on which this task is scheduled as busy
+ executorIdToTaskIds.getOrElseUpdate(executorId, new mutable.HashSet[Long]) += taskId
+ allocationManager.onExecutorBusy(executorId)
+ }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
+ val executorId = taskEnd.taskInfo.executorId
+ val taskId = taskEnd.taskInfo.taskId
+
+ // If the executor is no longer running scheduled any tasks, mark it as idle
+ if (executorIdToTaskIds.contains(executorId)) {
+ executorIdToTaskIds(executorId) -= taskId
+ if (executorIdToTaskIds(executorId).isEmpty) {
+ executorIdToTaskIds -= executorId
+ allocationManager.onExecutorIdle(executorId)
+ }
+ }
+ }
+
+ override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = {
+ val executorId = blockManagerAdded.blockManagerId.executorId
+ if (executorId != SparkContext.DRIVER_IDENTIFIER) {
+ allocationManager.onExecutorAdded(executorId)
+ }
+ }
+
+ override def onBlockManagerRemoved(
+ blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = {
+ allocationManager.onExecutorRemoved(blockManagerRemoved.blockManagerId.executorId)
+ }
+
+ /**
+ * An estimate of the total number of pending tasks remaining for currently running stages. Does
+ * not account for tasks which may have failed and been resubmitted.
+ */
+ def totalPendingTasks(): Int = {
+ stageIdToNumTasks.map { case (stageId, numTasks) =>
+ numTasks - stageIdToTaskIndices.get(stageId).map(_.size).getOrElse(0)
+ }.sum
+ }
+ }
+
+}
+
+private object ExecutorAllocationManager {
+ val NOT_SET = Long.MaxValue
+}
+
+/**
+ * An abstract clock for measuring elapsed time.
+ */
+private trait Clock {
+ def getTimeMillis: Long
+}
+
+/**
+ * A clock backed by a monotonically increasing time source.
+ * The time returned by this clock does not correspond to any notion of wall-clock time.
+ */
+private class RealClock extends Clock {
+ override def getTimeMillis: Long = System.nanoTime / (1000 * 1000)
+}
+
+/**
+ * A clock that allows the caller to customize the time.
+ * This is used mainly for testing.
+ */
+private class TestClock(startTimeMillis: Long) extends Clock {
+ private var time: Long = startTimeMillis
+ override def getTimeMillis: Long = time
+ def tick(ms: Long): Unit = { time += ms }
+}
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index e8f761eaa5799..e97a7375a267b 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -17,20 +17,21 @@
package org.apache.spark
-import scala.concurrent._
-import scala.concurrent.duration.Duration
-import scala.util.Try
+import java.util.Collections
+import java.util.concurrent.TimeUnit
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaFutureAction
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{JobFailed, JobSucceeded, JobWaiter}
+import scala.concurrent._
+import scala.concurrent.duration.Duration
+import scala.util.{Failure, Try}
+
/**
- * :: Experimental ::
* A future for the result of an action to support cancellation. This is an extension of the
* Scala Future interface to support cancellation.
*/
-@Experimental
trait FutureAction[T] extends Future[T] {
// Note that we redefine methods of the Future trait here explicitly so we can specify a different
// documentation (with reference to the word "action").
@@ -69,6 +70,11 @@ trait FutureAction[T] extends Future[T] {
*/
override def isCompleted: Boolean
+ /**
+ * Returns whether the action has been cancelled.
+ */
+ def isCancelled: Boolean
+
/**
* The value of this Future.
*
@@ -96,15 +102,16 @@ trait FutureAction[T] extends Future[T] {
/**
- * :: Experimental ::
* A [[FutureAction]] holding the result of an action that triggers a single job. Examples include
* count, collect, reduce.
*/
-@Experimental
class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
extends FutureAction[T] {
+ @volatile private var _cancelled: Boolean = false
+
override def cancel() {
+ _cancelled = true
jobWaiter.cancel()
}
@@ -143,6 +150,8 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
}
override def isCompleted: Boolean = jobWaiter.jobFinished
+
+ override def isCancelled: Boolean = _cancelled
override def value: Option[Try[T]] = {
if (jobWaiter.jobFinished) {
@@ -164,12 +173,10 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
/**
- * :: Experimental ::
* A [[FutureAction]] for actions that could trigger multiple Spark jobs. Examples include take,
* takeSample. Cancellation works by setting the cancelled flag to true and interrupting the
* action thread if it is being blocked by a job.
*/
-@Experimental
class ComplexFutureAction[T] extends FutureAction[T] {
// Pointer to the thread that is executing the action. It is set when the action is run.
@@ -203,7 +210,11 @@ class ComplexFutureAction[T] extends FutureAction[T] {
} catch {
case e: Exception => p.failure(e)
} finally {
- thread = null
+ // This lock guarantees when calling `thread.interrupt()` in `cancel`,
+ // thread won't be set to null.
+ ComplexFutureAction.this.synchronized {
+ thread = null
+ }
}
}
this
@@ -222,7 +233,7 @@ class ComplexFutureAction[T] extends FutureAction[T] {
// If the action hasn't been cancelled yet, submit the job. The check and the submitJob
// command need to be in an atomic block.
val job = this.synchronized {
- if (!cancelled) {
+ if (!isCancelled) {
rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc)
} else {
throw new SparkException("Action has been cancelled")
@@ -243,10 +254,7 @@ class ComplexFutureAction[T] extends FutureAction[T] {
}
}
- /**
- * Returns whether the promise has been cancelled.
- */
- def cancelled: Boolean = _cancelled
+ override def isCancelled: Boolean = _cancelled
@throws(classOf[InterruptedException])
@throws(classOf[scala.concurrent.TimeoutException])
@@ -271,3 +279,55 @@ class ComplexFutureAction[T] extends FutureAction[T] {
def jobIds = jobs
}
+
+private[spark]
+class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S => T)
+ extends JavaFutureAction[T] {
+
+ import scala.collection.JavaConverters._
+
+ override def isCancelled: Boolean = futureAction.isCancelled
+
+ override def isDone: Boolean = {
+ // According to java.util.Future's Javadoc, this returns True if the task was completed,
+ // whether that completion was due to successful execution, an exception, or a cancellation.
+ futureAction.isCancelled || futureAction.isCompleted
+ }
+
+ override def jobIds(): java.util.List[java.lang.Integer] = {
+ Collections.unmodifiableList(futureAction.jobIds.map(Integer.valueOf).asJava)
+ }
+
+ private def getImpl(timeout: Duration): T = {
+ // This will throw TimeoutException on timeout:
+ Await.ready(futureAction, timeout)
+ futureAction.value.get match {
+ case scala.util.Success(value) => converter(value)
+ case Failure(exception) =>
+ if (isCancelled) {
+ throw new CancellationException("Job cancelled").initCause(exception)
+ } else {
+ // java.util.Future.get() wraps exceptions in ExecutionException
+ throw new ExecutionException("Exception thrown by job", exception)
+ }
+ }
+ }
+
+ override def get(): T = getImpl(Duration.Inf)
+
+ override def get(timeout: Long, unit: TimeUnit): T =
+ getImpl(Duration.fromNanos(unit.toNanos(timeout)))
+
+ override def cancel(mayInterruptIfRunning: Boolean): Boolean = synchronized {
+ if (isDone) {
+ // According to java.util.Future's Javadoc, this should return false if the task is completed.
+ false
+ } else {
+ // We're limited in terms of the semantics we can provide here; our cancellation is
+ // asynchronous and doesn't provide a mechanism to not cancel if the job is running.
+ futureAction.cancel()
+ true
+ }
+ }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 4cb0bd4142435..7d96962c4acd7 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -178,6 +178,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
}
} else {
+ logError("Missing all output locations for shuffle " + shuffleId)
throw new MetadataFetchFailedException(
shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId)
}
@@ -348,7 +349,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
new ConcurrentHashMap[Int, Array[MapStatus]]
}
-private[spark] object MapOutputTracker {
+private[spark] object MapOutputTracker extends Logging {
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
@@ -381,6 +382,7 @@ private[spark] object MapOutputTracker {
statuses.map {
status =>
if (status == null) {
+ logError("Missing an output location for shuffle " + shuffleId)
throw new MetadataFetchFailedException(
shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId)
} else {
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index 37053bb6f37ad..e53a78ead2c0e 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -204,7 +204,7 @@ class RangePartitioner[K : Ordering : ClassTag, V](
}
@throws(classOf[IOException])
- private def writeObject(out: ObjectOutputStream) {
+ private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
val sfactory = SparkEnv.get.serializer
sfactory match {
case js: JavaSerializer => out.defaultWriteObject()
@@ -222,7 +222,7 @@ class RangePartitioner[K : Ordering : ClassTag, V](
}
@throws(classOf[IOException])
- private def readObject(in: ObjectInputStream) {
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
val sfactory = SparkEnv.get.serializer
sfactory match {
case js: JavaSerializer => in.defaultReadObject()
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 0e0f1a7b2377e..dbff9d12b5ad7 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -22,6 +22,7 @@ import java.net.{Authenticator, PasswordAuthentication}
import org.apache.hadoop.io.Text
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.network.sasl.SecretKeyHolder
/**
* Spark class responsible for security.
@@ -84,7 +85,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
* Authenticator installed in the SecurityManager to how it does the authentication
* and in this case gets the user name and password from the request.
*
- * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously
+ * - BlockTransferService -> The Spark BlockTransferServices uses java nio to asynchronously
* exchange messages. For this we use the Java SASL
* (Simple Authentication and Security Layer) API and again use DIGEST-MD5
* as the authentication mechanism. This means the shared secret is not passed
@@ -98,7 +99,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
* of protection they want. If we support those, the messages will also have to
* be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's.
*
- * Since the connectionManager does asynchronous messages passing, the SASL
+ * Since the NioBlockTransferService does asynchronous messages passing, the SASL
* authentication is a bit more complex. A ConnectionManager can be both a client
* and a Server, so for a particular connection is has to determine what to do.
* A ConnectionId was added to be able to track connections and is used to
@@ -107,6 +108,10 @@ import org.apache.spark.deploy.SparkHadoopUtil
* and waits for the response from the server and does the handshake before sending
* the real message.
*
+ * The NettyBlockTransferService ensures that SASL authentication is performed
+ * synchronously prior to any other communication on a connection. This is done in
+ * SaslClientBootstrap on the client side and SaslRpcHandler on the server side.
+ *
* - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters
* can be used. Yarn requires a specific AmIpFilter be installed for security to work
* properly. For non-Yarn deployments, users can write a filter to go through a
@@ -139,7 +144,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
* can take place.
*/
-private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
+private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with SecretKeyHolder {
// key used to store the spark secret in the Hadoop UGI
private val sparkSecretLookupKey = "sparkCookie"
@@ -337,4 +342,8 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
* @return the secret key as a String if authentication is enabled, otherwise returns null
*/
def getSecretKey(): String = secretKey
+
+ // Default SecurityManager only has a single secret key, so ignore appId.
+ override def getSaslUser(appId: String): String = getSaslUser()
+ override def getSecretKey(appId: String): String = getSecretKey()
}
diff --git a/core/src/main/scala/org/apache/spark/SerializableWritable.scala b/core/src/main/scala/org/apache/spark/SerializableWritable.scala
index e50b9ac2291f9..55cb25946c2ad 100644
--- a/core/src/main/scala/org/apache/spark/SerializableWritable.scala
+++ b/core/src/main/scala/org/apache/spark/SerializableWritable.scala
@@ -24,18 +24,19 @@ import org.apache.hadoop.io.ObjectWritable
import org.apache.hadoop.io.Writable
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.util.Utils
@DeveloperApi
class SerializableWritable[T <: Writable](@transient var t: T) extends Serializable {
def value = t
override def toString = t.toString
- private def writeObject(out: ObjectOutputStream) {
+ private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
out.defaultWriteObject()
new ObjectWritable(t).write(out)
}
- private def readObject(in: ObjectInputStream) {
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
in.defaultReadObject()
val ow = new ObjectWritable()
ow.setConf(new Configuration())
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index 605df0e929faa..4c6c86c7bad78 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -18,7 +18,8 @@
package org.apache.spark
import scala.collection.JavaConverters._
-import scala.collection.mutable.HashMap
+import scala.collection.mutable.{HashMap, LinkedHashSet}
+import org.apache.spark.serializer.KryoSerializer
/**
* Configuration for a Spark application. Used to set various Spark parameters as key-value pairs.
@@ -140,6 +141,20 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
this
}
+ /**
+ * Use Kryo serialization and register the given set of classes with Kryo.
+ * If called multiple times, this will append the classes from all calls together.
+ */
+ def registerKryoClasses(classes: Array[Class[_]]): SparkConf = {
+ val allClassNames = new LinkedHashSet[String]()
+ allClassNames ++= get("spark.kryo.classesToRegister", "").split(',').filter(!_.isEmpty)
+ allClassNames ++= classes.map(_.getName)
+
+ set("spark.kryo.classesToRegister", allClassNames.mkString(","))
+ set("spark.serializer", classOf[KryoSerializer].getName)
+ this
+ }
+
/** Remove a parameter from the configuration */
def remove(key: String): SparkConf = {
settings.remove(key)
@@ -202,6 +217,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
*/
getAll.filter { case (k, _) => isAkkaConf(k) }
+ /**
+ * Returns the Spark application id, valid in the Driver after TaskScheduler registration and
+ * from the start in the Executor.
+ */
+ def getAppId: String = get("spark.app.id")
+
/** Does the configuration contain a given parameter? */
def contains(key: String): Boolean = settings.contains(key)
@@ -229,6 +250,19 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
val executorClasspathKey = "spark.executor.extraClassPath"
val driverOptsKey = "spark.driver.extraJavaOptions"
val driverClassPathKey = "spark.driver.extraClassPath"
+ val driverLibraryPathKey = "spark.driver.extraLibraryPath"
+
+ // Used by Yarn in 1.1 and before
+ sys.props.get("spark.driver.libraryPath").foreach { value =>
+ val warning =
+ s"""
+ |spark.driver.libraryPath was detected (set to '$value').
+ |This is deprecated in Spark 1.2+.
+ |
+ |Please instead use: $driverLibraryPathKey
+ """.stripMargin
+ logWarning(warning)
+ }
// Validate spark.executor.extraJavaOptions
settings.get(executorOptsKey).map { javaOpts =>
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 97109b9f41b60..9b0d5be7a7ab2 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -21,8 +21,8 @@ import scala.language.implicitConversions
import java.io._
import java.net.URI
+import java.util.{Arrays, Properties, UUID}
import java.util.concurrent.atomic.AtomicInteger
-import java.util.{Properties, UUID}
import java.util.UUID.randomUUID
import scala.collection.{Map, Set}
import scala.collection.JavaConversions._
@@ -41,7 +41,8 @@ import akka.actor.Props
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
-import org.apache.spark.input.WholeTextFileInputFormat
+import org.apache.spark.executor.TriggerThreadDump
+import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat}
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
@@ -49,24 +50,41 @@ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkD
import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
import org.apache.spark.scheduler.local.LocalBackend
import org.apache.spark.storage._
-import org.apache.spark.ui.SparkUI
-import org.apache.spark.util.{CallSite, ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils}
+import org.apache.spark.ui.{SparkUI, ConsoleProgressBar}
+import org.apache.spark.ui.jobs.JobProgressListener
+import org.apache.spark.util._
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
* cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster.
*
+ * Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before
+ * creating a new one. This limitation may eventually be removed; see SPARK-2243 for more details.
+ *
* @param config a Spark Config object describing the application configuration. Any settings in
* this config overrides the default configs as well as system properties.
*/
-
class SparkContext(config: SparkConf) extends Logging {
+ // The call site where this SparkContext was constructed.
+ private val creationSite: CallSite = Utils.getCallSite()
+
+ // If true, log warnings instead of throwing exceptions when multiple SparkContexts are active
+ private val allowMultipleContexts: Boolean =
+ config.getBoolean("spark.driver.allowMultipleContexts", false)
+
+ // In order to prevent multiple SparkContexts from being active at the same time, mark this
+ // context as having started construction.
+ // NOTE: this must be placed at the beginning of the SparkContext constructor.
+ SparkContext.markPartiallyConstructed(this, allowMultipleContexts)
+
// This is used only by YARN for now, but should be relevant to other cluster types (Mesos,
// etc) too. This is typically generated from InputFormatInfo.computePreferredLocations. It
// contains a map from hostname to a list of input format splits on the host.
private[spark] var preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map()
+ val startTime = System.currentTimeMillis()
+
/**
* Create a SparkContext that loads settings from system properties (for instance, when
* launching with ./bin/spark-submit).
@@ -208,16 +226,10 @@ class SparkContext(config: SparkConf) extends Logging {
// An asynchronous listener bus for Spark events
private[spark] val listenerBus = new LiveListenerBus
- // Create the Spark execution environment (cache, map output tracker, etc)
conf.set("spark.executor.id", "driver")
- private[spark] val env = SparkEnv.create(
- conf,
- "",
- conf.get("spark.driver.host"),
- conf.get("spark.driver.port").toInt,
- isDriver = true,
- isLocal = isLocal,
- listenerBus = listenerBus)
+
+ // Create the Spark execution environment (cache, map output tracker, etc)
+ private[spark] val env = SparkEnv.createDriverEnv(conf, isLocal, listenerBus)
SparkEnv.set(env)
// Used to store a URL for each static file/jar together with the file's local timestamp
@@ -229,21 +241,36 @@ class SparkContext(config: SparkConf) extends Logging {
private[spark] val metadataCleaner =
new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf)
- // Initialize the Spark UI, registering all associated listeners
+
+ private[spark] val jobProgressListener = new JobProgressListener(conf)
+ listenerBus.addListener(jobProgressListener)
+
+ val statusTracker = new SparkStatusTracker(this)
+
+ private[spark] val progressBar: Option[ConsoleProgressBar] =
+ if (conf.getBoolean("spark.ui.showConsoleProgress", true) && !log.isInfoEnabled) {
+ Some(new ConsoleProgressBar(this))
+ } else {
+ None
+ }
+
+ // Initialize the Spark UI
private[spark] val ui: Option[SparkUI] =
if (conf.getBoolean("spark.ui.enabled", true)) {
- Some(new SparkUI(this))
+ Some(SparkUI.createLiveUI(this, conf, listenerBus, jobProgressListener,
+ env.securityManager,appName))
} else {
// For tests, do not enable the UI
None
}
+
+ // Bind the UI before starting the task scheduler to communicate
+ // the bound port to the cluster manager properly
ui.foreach(_.bind())
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf)
- val startTime = System.currentTimeMillis()
-
// Add each JAR given through the constructor
if (jars != null) {
jars.foreach(addJar)
@@ -291,7 +318,8 @@ class SparkContext(config: SparkConf) extends Logging {
executorEnvs("SPARK_USER") = sparkUser
// Create and start the scheduler
- private[spark] var taskScheduler = SparkContext.createTaskScheduler(this, master)
+ private[spark] var (schedulerBackend, taskScheduler) =
+ SparkContext.createTaskScheduler(this, master)
private val heartbeatReceiver = env.actorSystem.actorOf(
Props(new HeartbeatReceiver(taskScheduler)), "HeartbeatReceiver")
@volatile private[spark] var dagScheduler: DAGScheduler = _
@@ -309,6 +337,8 @@ class SparkContext(config: SparkConf) extends Logging {
val applicationId: String = taskScheduler.applicationId()
conf.set("spark.app.id", applicationId)
+ env.blockManager.initialize(applicationId)
+
val metricsSystem = env.metricsSystem
// The metrics system for Driver need to be set spark.app.id to app ID.
@@ -326,6 +356,15 @@ class SparkContext(config: SparkConf) extends Logging {
} else None
}
+ // Optionally scale number of executors dynamically based on workload. Exposed for testing.
+ private[spark] val executorAllocationManager: Option[ExecutorAllocationManager] =
+ if (conf.getBoolean("spark.dynamicAllocation.enabled", false)) {
+ Some(new ExecutorAllocationManager(this))
+ } else {
+ None
+ }
+ executorAllocationManager.foreach(_.start())
+
// At this point, all relevant SparkListeners have been registered, so begin releasing events
listenerBus.start()
@@ -348,6 +387,29 @@ class SparkContext(config: SparkConf) extends Logging {
override protected def childValue(parent: Properties): Properties = new Properties(parent)
}
+ /**
+ * Called by the web UI to obtain executor thread dumps. This method may be expensive.
+ * Logs an error and returns None if we failed to obtain a thread dump, which could occur due
+ * to an executor being dead or unresponsive or due to network issues while sending the thread
+ * dump message back to the driver.
+ */
+ private[spark] def getExecutorThreadDump(executorId: String): Option[Array[ThreadStackTrace]] = {
+ try {
+ if (executorId == SparkContext.DRIVER_IDENTIFIER) {
+ Some(Utils.getThreadDump())
+ } else {
+ val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get
+ val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem)
+ Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef,
+ AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf)))
+ }
+ } catch {
+ case e: Exception =>
+ logError(s"Exception getting thread dump from executor $executorId", e)
+ None
+ }
+ }
+
private[spark] def getLocalProperties: Properties = localProperties.get()
private[spark] def setLocalProperties(props: Properties) {
@@ -520,6 +582,73 @@ class SparkContext(config: SparkConf) extends Logging {
minPartitions).setName(path)
}
+
+ /**
+ * :: Experimental ::
+ *
+ * Get an RDD for a Hadoop-readable dataset as PortableDataStream for each file
+ * (useful for binary data)
+ *
+ * For example, if you have the following files:
+ * {{{
+ * hdfs://a-hdfs-path/part-00000
+ * hdfs://a-hdfs-path/part-00001
+ * ...
+ * hdfs://a-hdfs-path/part-nnnnn
+ * }}}
+ *
+ * Do
+ * `val rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`,
+ *
+ * then `rdd` contains
+ * {{{
+ * (a-hdfs-path/part-00000, its content)
+ * (a-hdfs-path/part-00001, its content)
+ * ...
+ * (a-hdfs-path/part-nnnnn, its content)
+ * }}}
+ *
+ * @param minPartitions A suggestion value of the minimal splitting number for input data.
+ *
+ * @note Small files are preferred; very large files may cause bad performance.
+ */
+ @Experimental
+ def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
+ RDD[(String, PortableDataStream)] = {
+ val job = new NewHadoopJob(hadoopConfiguration)
+ NewFileInputFormat.addInputPath(job, new Path(path))
+ val updateConf = job.getConfiguration
+ new BinaryFileRDD(
+ this,
+ classOf[StreamInputFormat],
+ classOf[String],
+ classOf[PortableDataStream],
+ updateConf,
+ minPartitions).setName(path)
+ }
+
+ /**
+ * :: Experimental ::
+ *
+ * Load data from a flat binary file, assuming the length of each record is constant.
+ *
+ * @param path Directory to the input data files
+ * @param recordLength The length at which to split the records
+ * @return An RDD of data with values, represented as byte arrays
+ */
+ @Experimental
+ def binaryRecords(path: String, recordLength: Int, conf: Configuration = hadoopConfiguration)
+ : RDD[Array[Byte]] = {
+ conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength)
+ val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path,
+ classOf[FixedLengthBinaryInputFormat],
+ classOf[LongWritable],
+ classOf[BytesWritable],
+ conf=conf)
+ val data = br.map{ case (k, v) => v.getBytes}
+ data
+ }
+
/**
* Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and other
* necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable),
@@ -779,20 +908,20 @@ class SparkContext(config: SparkConf) extends Logging {
/**
* Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values
* with `+=`. Only the driver can access the accumuable's `value`.
- * @tparam T accumulator type
- * @tparam R type that can be added to the accumulator
+ * @tparam R accumulator result type
+ * @tparam T type that can be added to the accumulator
*/
- def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) =
+ def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) =
new Accumulable(initialValue, param)
/**
* Create an [[org.apache.spark.Accumulable]] shared variable, with a name for display in the
* Spark UI. Tasks can add values to the accumuable using the `+=` operator. Only the driver can
* access the accumuable's `value`.
- * @tparam T accumulator type
- * @tparam R type that can be added to the accumulator
+ * @tparam R accumulator result type
+ * @tparam T type that can be added to the accumulator
*/
- def accumulable[T, R](initialValue: T, name: String)(implicit param: AccumulableParam[T, R]) =
+ def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) =
new Accumulable(initialValue, param, Some(name))
/**
@@ -814,6 +943,8 @@ class SparkContext(config: SparkConf) extends Logging {
*/
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
+ val callSite = getCallSite
+ logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
cleaner.foreach(_.registerBroadcastForCleanup(bc))
bc
}
@@ -831,11 +962,12 @@ class SparkContext(config: SparkConf) extends Logging {
case "local" => "file:" + uri.getPath
case _ => path
}
- addedFiles(key) = System.currentTimeMillis
+ val timestamp = System.currentTimeMillis
+ addedFiles(key) = timestamp
// Fetch the file locally in case a job is executed using DAGScheduler.runLocally().
Utils.fetchFile(path, new File(SparkFiles.getRootDirectory()), conf, env.securityManager,
- hadoopConfiguration)
+ hadoopConfiguration, timestamp, useCache = false)
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
postEnvironmentUpdate()
@@ -850,6 +982,46 @@ class SparkContext(config: SparkConf) extends Logging {
listenerBus.addListener(listener)
}
+ /**
+ * :: DeveloperApi ::
+ * Request an additional number of executors from the cluster manager.
+ * This is currently only supported in Yarn mode. Return whether the request is received.
+ */
+ @DeveloperApi
+ def requestExecutors(numAdditionalExecutors: Int): Boolean = {
+ schedulerBackend match {
+ case b: CoarseGrainedSchedulerBackend =>
+ b.requestExecutors(numAdditionalExecutors)
+ case _ =>
+ logWarning("Requesting executors is only supported in coarse-grained mode")
+ false
+ }
+ }
+
+ /**
+ * :: DeveloperApi ::
+ * Request that the cluster manager kill the specified executors.
+ * This is currently only supported in Yarn mode. Return whether the request is received.
+ */
+ @DeveloperApi
+ def killExecutors(executorIds: Seq[String]): Boolean = {
+ schedulerBackend match {
+ case b: CoarseGrainedSchedulerBackend =>
+ b.killExecutors(executorIds)
+ case _ =>
+ logWarning("Killing executors is only supported in coarse-grained mode")
+ false
+ }
+ }
+
+ /**
+ * :: DeveloperApi ::
+ * Request that cluster manager the kill the specified executor.
+ * This is currently only supported in Yarn mode. Return whether the request is received.
+ */
+ @DeveloperApi
+ def killExecutor(executorId: String): Boolean = killExecutors(Seq(executorId))
+
/** The version of Spark on which this application is running. */
def version = SPARK_VERSION
@@ -1015,27 +1187,30 @@ class SparkContext(config: SparkConf) extends Logging {
/** Shut down the SparkContext. */
def stop() {
- postApplicationEnd()
- ui.foreach(_.stop())
- // Do this only if not stopped already - best case effort.
- // prevent NPE if stopped more than once.
- val dagSchedulerCopy = dagScheduler
- dagScheduler = null
- if (dagSchedulerCopy != null) {
- env.metricsSystem.report()
- metadataCleaner.cancel()
- env.actorSystem.stop(heartbeatReceiver)
- cleaner.foreach(_.stop())
- dagSchedulerCopy.stop()
- taskScheduler = null
- // TODO: Cache.stop()?
- env.stop()
- SparkEnv.set(null)
- listenerBus.stop()
- eventLogger.foreach(_.stop())
- logInfo("Successfully stopped SparkContext")
- } else {
- logInfo("SparkContext already stopped")
+ SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+ postApplicationEnd()
+ ui.foreach(_.stop())
+ // Do this only if not stopped already - best case effort.
+ // prevent NPE if stopped more than once.
+ val dagSchedulerCopy = dagScheduler
+ dagScheduler = null
+ if (dagSchedulerCopy != null) {
+ env.metricsSystem.report()
+ metadataCleaner.cancel()
+ env.actorSystem.stop(heartbeatReceiver)
+ cleaner.foreach(_.stop())
+ dagSchedulerCopy.stop()
+ taskScheduler = null
+ // TODO: Cache.stop()?
+ env.stop()
+ SparkEnv.set(null)
+ listenerBus.stop()
+ eventLogger.foreach(_.stop())
+ logInfo("Successfully stopped SparkContext")
+ SparkContext.clearActiveContext()
+ } else {
+ logInfo("SparkContext already stopped")
+ }
}
}
@@ -1106,6 +1281,7 @@ class SparkContext(config: SparkConf) extends Logging {
logInfo("Starting job: " + callSite.shortForm)
dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
resultHandler, localProperties.get)
+ progressBar.foreach(_.finishAll())
rdd.doCheckpoint()
}
@@ -1324,6 +1500,11 @@ class SparkContext(config: SparkConf) extends Logging {
private[spark] def cleanup(cleanupTime: Long) {
persistentRdds.clearOldValues(cleanupTime)
}
+
+ // In order to prevent multiple SparkContexts from being active at the same time, mark this
+ // context as having finished construction.
+ // NOTE: this must be placed at the end of the SparkContext constructor.
+ SparkContext.setActiveContext(this, allowMultipleContexts)
}
/**
@@ -1332,6 +1513,107 @@ class SparkContext(config: SparkConf) extends Logging {
*/
object SparkContext extends Logging {
+ /**
+ * Lock that guards access to global variables that track SparkContext construction.
+ */
+ private val SPARK_CONTEXT_CONSTRUCTOR_LOCK = new Object()
+
+ /**
+ * The active, fully-constructed SparkContext. If no SparkContext is active, then this is `None`.
+ *
+ * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK
+ */
+ private var activeContext: Option[SparkContext] = None
+
+ /**
+ * Points to a partially-constructed SparkContext if some thread is in the SparkContext
+ * constructor, or `None` if no SparkContext is being constructed.
+ *
+ * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK
+ */
+ private var contextBeingConstructed: Option[SparkContext] = None
+
+ /**
+ * Called to ensure that no other SparkContext is running in this JVM.
+ *
+ * Throws an exception if a running context is detected and logs a warning if another thread is
+ * constructing a SparkContext. This warning is necessary because the current locking scheme
+ * prevents us from reliably distinguishing between cases where another context is being
+ * constructed and cases where another constructor threw an exception.
+ */
+ private def assertNoOtherContextIsRunning(
+ sc: SparkContext,
+ allowMultipleContexts: Boolean): Unit = {
+ SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+ contextBeingConstructed.foreach { otherContext =>
+ if (otherContext ne sc) { // checks for reference equality
+ // Since otherContext might point to a partially-constructed context, guard against
+ // its creationSite field being null:
+ val otherContextCreationSite =
+ Option(otherContext.creationSite).map(_.longForm).getOrElse("unknown location")
+ val warnMsg = "Another SparkContext is being constructed (or threw an exception in its" +
+ " constructor). This may indicate an error, since only one SparkContext may be" +
+ " running in this JVM (see SPARK-2243)." +
+ s" The other SparkContext was created at:\n$otherContextCreationSite"
+ logWarning(warnMsg)
+ }
+
+ activeContext.foreach { ctx =>
+ val errMsg = "Only one SparkContext may be running in this JVM (see SPARK-2243)." +
+ " To ignore this error, set spark.driver.allowMultipleContexts = true. " +
+ s"The currently running SparkContext was created at:\n${ctx.creationSite.longForm}"
+ val exception = new SparkException(errMsg)
+ if (allowMultipleContexts) {
+ logWarning("Multiple running SparkContexts detected in the same JVM!", exception)
+ } else {
+ throw exception
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Called at the beginning of the SparkContext constructor to ensure that no SparkContext is
+ * running. Throws an exception if a running context is detected and logs a warning if another
+ * thread is constructing a SparkContext. This warning is necessary because the current locking
+ * scheme prevents us from reliably distinguishing between cases where another context is being
+ * constructed and cases where another constructor threw an exception.
+ */
+ private[spark] def markPartiallyConstructed(
+ sc: SparkContext,
+ allowMultipleContexts: Boolean): Unit = {
+ SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+ assertNoOtherContextIsRunning(sc, allowMultipleContexts)
+ contextBeingConstructed = Some(sc)
+ }
+ }
+
+ /**
+ * Called at the end of the SparkContext constructor to ensure that no other SparkContext has
+ * raced with this constructor and started.
+ */
+ private[spark] def setActiveContext(
+ sc: SparkContext,
+ allowMultipleContexts: Boolean): Unit = {
+ SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+ assertNoOtherContextIsRunning(sc, allowMultipleContexts)
+ contextBeingConstructed = None
+ activeContext = Some(sc)
+ }
+ }
+
+ /**
+ * Clears the active SparkContext metadata. This is called by `SparkContext#stop()`. It's
+ * also called in unit tests to prevent a flood of warnings from test suites that don't / can't
+ * properly clean up their SparkContexts.
+ */
+ private[spark] def clearActiveContext(): Unit = {
+ SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+ activeContext = None
+ }
+ }
+
private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description"
private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"
@@ -1340,47 +1622,76 @@ object SparkContext extends Logging {
private[spark] val SPARK_UNKNOWN_USER = ""
- implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
+ private[spark] val DRIVER_IDENTIFIER = ""
+
+ // The following deprecated objects have already been copied to `object AccumulatorParam` to
+ // make the compiler find them automatically. They are duplicate codes only for backward
+ // compatibility, please update `object AccumulatorParam` accordingly if you plan to modify the
+ // following ones.
+
+ @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " +
+ "backward compatibility.", "1.2.0")
+ object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
def zero(initialValue: Double) = 0.0
}
- implicit object IntAccumulatorParam extends AccumulatorParam[Int] {
+ @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " +
+ "backward compatibility.", "1.2.0")
+ object IntAccumulatorParam extends AccumulatorParam[Int] {
def addInPlace(t1: Int, t2: Int): Int = t1 + t2
def zero(initialValue: Int) = 0
}
- implicit object LongAccumulatorParam extends AccumulatorParam[Long] {
+ @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " +
+ "backward compatibility.", "1.2.0")
+ object LongAccumulatorParam extends AccumulatorParam[Long] {
def addInPlace(t1: Long, t2: Long) = t1 + t2
def zero(initialValue: Long) = 0L
}
- implicit object FloatAccumulatorParam extends AccumulatorParam[Float] {
+ @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " +
+ "backward compatibility.", "1.2.0")
+ object FloatAccumulatorParam extends AccumulatorParam[Float] {
def addInPlace(t1: Float, t2: Float) = t1 + t2
def zero(initialValue: Float) = 0f
}
- // TODO: Add AccumulatorParams for other types, e.g. lists and strings
+ // The following deprecated functions have already been moved to `object RDD` to
+ // make the compiler find them automatically. They are still kept here for backward compatibility
+ // and just call the corresponding functions in `object RDD`.
- implicit def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)])
+ @deprecated("Replaced by implicit functions in the RDD companion object. This is " +
+ "kept here only for backward compatibility.", "1.2.0")
+ def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)])
(implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = {
- new PairRDDFunctions(rdd)
+ RDD.rddToPairRDDFunctions(rdd)
}
- implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = new AsyncRDDActions(rdd)
+ @deprecated("Replaced by implicit functions in the RDD companion object. This is " +
+ "kept here only for backward compatibility.", "1.2.0")
+ def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = RDD.rddToAsyncRDDActions(rdd)
- implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag](
+ @deprecated("Replaced by implicit functions in the RDD companion object. This is " +
+ "kept here only for backward compatibility.", "1.2.0")
+ def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag](
rdd: RDD[(K, V)]) =
- new SequenceFileRDDFunctions(rdd)
+ RDD.rddToSequenceFileRDDFunctions(rdd)
- implicit def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag](
+ @deprecated("Replaced by implicit functions in the RDD companion object. This is " +
+ "kept here only for backward compatibility.", "1.2.0")
+ def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag](
rdd: RDD[(K, V)]) =
- new OrderedRDDFunctions[K, V, (K, V)](rdd)
+ RDD.rddToOrderedRDDFunctions(rdd)
- implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = new DoubleRDDFunctions(rdd)
+ @deprecated("Replaced by implicit functions in the RDD companion object. This is " +
+ "kept here only for backward compatibility.", "1.2.0")
+ def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = RDD.doubleRDDToDoubleRDDFunctions(rdd)
- implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) =
- new DoubleRDDFunctions(rdd.map(x => num.toDouble(x)))
+ @deprecated("Replaced by implicit functions in the RDD companion object. This is " +
+ "kept here only for backward compatibility.", "1.2.0")
+ def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) =
+ RDD.numericRDDToDoubleRDDFunctions(rdd)
// Implicit conversions to common Writable types, for saveAsSequenceFile
@@ -1406,37 +1717,49 @@ object SparkContext extends Logging {
arr.map(x => anyToWritable(x)).toArray)
}
- // Helper objects for converting common types to Writable
- private def simpleWritableConverter[T, W <: Writable: ClassTag](convert: W => T)
- : WritableConverter[T] = {
- val wClass = classTag[W].runtimeClass.asInstanceOf[Class[W]]
- new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W]))
- }
+ // The following deprecated functions have already been moved to `object WritableConverter` to
+ // make the compiler find them automatically. They are still kept here for backward compatibility
+ // and just call the corresponding functions in `object WritableConverter`.
- implicit def intWritableConverter(): WritableConverter[Int] =
- simpleWritableConverter[Int, IntWritable](_.get)
+ @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " +
+ "backward compatibility.", "1.2.0")
+ def intWritableConverter(): WritableConverter[Int] =
+ WritableConverter.intWritableConverter()
- implicit def longWritableConverter(): WritableConverter[Long] =
- simpleWritableConverter[Long, LongWritable](_.get)
+ @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " +
+ "backward compatibility.", "1.2.0")
+ def longWritableConverter(): WritableConverter[Long] =
+ WritableConverter.longWritableConverter()
- implicit def doubleWritableConverter(): WritableConverter[Double] =
- simpleWritableConverter[Double, DoubleWritable](_.get)
+ @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " +
+ "backward compatibility.", "1.2.0")
+ def doubleWritableConverter(): WritableConverter[Double] =
+ WritableConverter.doubleWritableConverter()
- implicit def floatWritableConverter(): WritableConverter[Float] =
- simpleWritableConverter[Float, FloatWritable](_.get)
+ @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " +
+ "backward compatibility.", "1.2.0")
+ def floatWritableConverter(): WritableConverter[Float] =
+ WritableConverter.floatWritableConverter()
- implicit def booleanWritableConverter(): WritableConverter[Boolean] =
- simpleWritableConverter[Boolean, BooleanWritable](_.get)
+ @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " +
+ "backward compatibility.", "1.2.0")
+ def booleanWritableConverter(): WritableConverter[Boolean] =
+ WritableConverter.booleanWritableConverter()
- implicit def bytesWritableConverter(): WritableConverter[Array[Byte]] = {
- simpleWritableConverter[Array[Byte], BytesWritable](_.getBytes)
- }
+ @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " +
+ "backward compatibility.", "1.2.0")
+ def bytesWritableConverter(): WritableConverter[Array[Byte]] =
+ WritableConverter.bytesWritableConverter()
- implicit def stringWritableConverter(): WritableConverter[String] =
- simpleWritableConverter[String, Text](_.toString)
+ @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " +
+ "backward compatibility.", "1.2.0")
+ def stringWritableConverter(): WritableConverter[String] =
+ WritableConverter.stringWritableConverter()
- implicit def writableWritableConverter[T <: Writable]() =
- new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T])
+ @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " +
+ "backward compatibility.", "1.2.0")
+ def writableWritableConverter[T <: Writable]() =
+ WritableConverter.writableWritableConverter()
/**
* Find the JAR from which a given class was loaded, to make it easy for users to pass
@@ -1492,8 +1815,13 @@ object SparkContext extends Logging {
res
}
- /** Creates a task scheduler based on a given master URL. Extracted for testing. */
- private def createTaskScheduler(sc: SparkContext, master: String): TaskScheduler = {
+ /**
+ * Create a task scheduler based on a given master URL.
+ * Return a 2-tuple of the scheduler backend and the task scheduler.
+ */
+ private def createTaskScheduler(
+ sc: SparkContext,
+ master: String): (SchedulerBackend, TaskScheduler) = {
// Regular expression used for local[N] and local[*] master formats
val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r
// Regular expression for local[N, maxRetries], used in tests with failing tasks
@@ -1515,16 +1843,19 @@ object SparkContext extends Logging {
val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
val backend = new LocalBackend(scheduler, 1)
scheduler.initialize(backend)
- scheduler
+ (backend, scheduler)
case LOCAL_N_REGEX(threads) =>
def localCpuCount = Runtime.getRuntime.availableProcessors()
// local[*] estimates the number of cores on the machine; local[N] uses exactly N threads.
val threadCount = if (threads == "*") localCpuCount else threads.toInt
+ if (threadCount <= 0) {
+ throw new SparkException(s"Asked to run locally with $threadCount threads")
+ }
val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
val backend = new LocalBackend(scheduler, threadCount)
scheduler.initialize(backend)
- scheduler
+ (backend, scheduler)
case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
def localCpuCount = Runtime.getRuntime.availableProcessors()
@@ -1534,14 +1865,14 @@ object SparkContext extends Logging {
val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true)
val backend = new LocalBackend(scheduler, threadCount)
scheduler.initialize(backend)
- scheduler
+ (backend, scheduler)
case SPARK_REGEX(sparkUrl) =>
val scheduler = new TaskSchedulerImpl(sc)
val masterUrls = sparkUrl.split(",").map("spark://" + _)
val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls)
scheduler.initialize(backend)
- scheduler
+ (backend, scheduler)
case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
// Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang.
@@ -1561,7 +1892,7 @@ object SparkContext extends Logging {
backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
localCluster.stop()
}
- scheduler
+ (backend, scheduler)
case "yarn-standalone" | "yarn-cluster" =>
if (master == "yarn-standalone") {
@@ -1590,7 +1921,7 @@ object SparkContext extends Logging {
}
}
scheduler.initialize(backend)
- scheduler
+ (backend, scheduler)
case "yarn-client" =>
val scheduler = try {
@@ -1617,7 +1948,7 @@ object SparkContext extends Logging {
}
scheduler.initialize(backend)
- scheduler
+ (backend, scheduler)
case mesosUrl @ MESOS_REGEX(_) =>
MesosNativeLibrary.load()
@@ -1630,13 +1961,13 @@ object SparkContext extends Logging {
new MesosSchedulerBackend(scheduler, sc, url)
}
scheduler.initialize(backend)
- scheduler
+ (backend, scheduler)
case SIMR_REGEX(simrUrl) =>
val scheduler = new TaskSchedulerImpl(sc)
val backend = new SimrSchedulerBackend(scheduler, sc, simrUrl)
scheduler.initialize(backend)
- scheduler
+ (backend, scheduler)
case _ =>
throw new SparkException("Could not parse Master URL: '" + master + "'")
@@ -1655,3 +1986,46 @@ private[spark] class WritableConverter[T](
val writableClass: ClassTag[T] => Class[_ <: Writable],
val convert: Writable => T)
extends Serializable
+
+object WritableConverter {
+
+ // Helper objects for converting common types to Writable
+ private[spark] def simpleWritableConverter[T, W <: Writable: ClassTag](convert: W => T)
+ : WritableConverter[T] = {
+ val wClass = classTag[W].runtimeClass.asInstanceOf[Class[W]]
+ new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W]))
+ }
+
+ // The following implicit functions were in SparkContext before 1.2 and users had to
+ // `import SparkContext._` to enable them. Now we move them here to make the compiler find
+ // them automatically. However, we still keep the old functions in SparkContext for backward
+ // compatibility and forward to the following functions directly.
+
+ implicit def intWritableConverter(): WritableConverter[Int] =
+ simpleWritableConverter[Int, IntWritable](_.get)
+
+ implicit def longWritableConverter(): WritableConverter[Long] =
+ simpleWritableConverter[Long, LongWritable](_.get)
+
+ implicit def doubleWritableConverter(): WritableConverter[Double] =
+ simpleWritableConverter[Double, DoubleWritable](_.get)
+
+ implicit def floatWritableConverter(): WritableConverter[Float] =
+ simpleWritableConverter[Float, FloatWritable](_.get)
+
+ implicit def booleanWritableConverter(): WritableConverter[Boolean] =
+ simpleWritableConverter[Boolean, BooleanWritable](_.get)
+
+ implicit def bytesWritableConverter(): WritableConverter[Array[Byte]] = {
+ simpleWritableConverter[Array[Byte], BytesWritable](bw =>
+ // getBytes method returns array which is longer then data to be returned
+ Arrays.copyOfRange(bw.getBytes, 0, bw.getLength)
+ )
+ }
+
+ implicit def stringWritableConverter(): WritableConverter[String] =
+ simpleWritableConverter[String, Text](_.toString)
+
+ implicit def writableWritableConverter[T <: Writable]() =
+ new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T])
+}
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 72cac42cd2b2b..e464b32e61dd6 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -32,6 +32,7 @@ import org.apache.spark.api.python.PythonWorkerFactory
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.BlockTransferService
+import org.apache.spark.network.netty.NettyBlockTransferService
import org.apache.spark.network.nio.NioBlockTransferService
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.Serializer
@@ -43,9 +44,8 @@ import org.apache.spark.util.{AkkaUtils, Utils}
* :: DeveloperApi ::
* Holds all the runtime environment objects for a running Spark instance (either master or worker),
* including the serializer, Akka actor system, block manager, map output tracker, etc. Currently
- * Spark code finds the SparkEnv through a thread-local variable, so each thread that accesses these
- * objects needs to have the right SparkEnv set. You can get the current environment with
- * SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set.
+ * Spark code finds the SparkEnv through a global variable, so all the threads can access the same
+ * SparkEnv. It can be accessed by SparkEnv.get (e.g. after creating a SparkContext).
*
* NOTE: This is not intended for external use. This is exposed for Shark and may be made private
* in a future release.
@@ -69,6 +69,7 @@ class SparkEnv (
val shuffleMemoryManager: ShuffleMemoryManager,
val conf: SparkConf) extends Logging {
+ private[spark] var isStopped = false
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
// A general, soft-reference map for metadata needed during HadoopRDD split computation
@@ -76,6 +77,7 @@ class SparkEnv (
private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
private[spark] def stop() {
+ isStopped = true
pythonWorkers.foreach { case(key, worker) => worker.stop() }
Option(httpFileServer).foreach(_.stop())
mapOutputTracker.stop()
@@ -119,40 +121,73 @@ class SparkEnv (
}
object SparkEnv extends Logging {
- private val env = new ThreadLocal[SparkEnv]
- @volatile private var lastSetSparkEnv : SparkEnv = _
+ @volatile private var env: SparkEnv = _
private[spark] val driverActorSystemName = "sparkDriver"
private[spark] val executorActorSystemName = "sparkExecutor"
def set(e: SparkEnv) {
- lastSetSparkEnv = e
- env.set(e)
+ env = e
}
/**
- * Returns the ThreadLocal SparkEnv, if non-null. Else returns the SparkEnv
- * previously set in any thread.
+ * Returns the SparkEnv.
*/
def get: SparkEnv = {
- Option(env.get()).getOrElse(lastSetSparkEnv)
+ env
}
/**
* Returns the ThreadLocal SparkEnv.
*/
+ @deprecated("Use SparkEnv.get instead", "1.2")
def getThreadLocal: SparkEnv = {
- env.get()
+ env
}
- private[spark] def create(
+ /**
+ * Create a SparkEnv for the driver.
+ */
+ private[spark] def createDriverEnv(
+ conf: SparkConf,
+ isLocal: Boolean,
+ listenerBus: LiveListenerBus): SparkEnv = {
+ assert(conf.contains("spark.driver.host"), "spark.driver.host is not set on the driver!")
+ assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!")
+ val hostname = conf.get("spark.driver.host")
+ val port = conf.get("spark.driver.port").toInt
+ create(conf, SparkContext.DRIVER_IDENTIFIER, hostname, port, true, isLocal, listenerBus)
+ }
+
+ /**
+ * Create a SparkEnv for an executor.
+ * In coarse-grained mode, the executor provides an actor system that is already instantiated.
+ */
+ private[spark] def createExecutorEnv(
+ conf: SparkConf,
+ executorId: String,
+ hostname: String,
+ port: Int,
+ numCores: Int,
+ isLocal: Boolean,
+ actorSystem: ActorSystem = null): SparkEnv = {
+ create(conf, executorId, hostname, port, false, isLocal, defaultActorSystem = actorSystem,
+ numUsableCores = numCores)
+ }
+
+ /**
+ * Helper method to create a SparkEnv for a driver or an executor.
+ */
+ private def create(
conf: SparkConf,
executorId: String,
hostname: String,
port: Int,
isDriver: Boolean,
isLocal: Boolean,
- listenerBus: LiveListenerBus = null): SparkEnv = {
+ listenerBus: LiveListenerBus = null,
+ defaultActorSystem: ActorSystem = null,
+ numUsableCores: Int = 0): SparkEnv = {
// Listener bus is only used on the driver
if (isDriver) {
@@ -160,9 +195,16 @@ object SparkEnv extends Logging {
}
val securityManager = new SecurityManager(conf)
- val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
- actorSystemName, hostname, port, conf, securityManager)
+
+ // If an existing actor system is already provided, use it.
+ // This is the case when an executor is launched in coarse-grained mode.
+ val (actorSystem, boundPort) =
+ Option(defaultActorSystem) match {
+ case Some(as) => (as, port)
+ case None =>
+ val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName
+ AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager)
+ }
// Figure out which port Akka actually bound to in case the original port is 0 or occupied.
// This is so that we tell the executors the correct port to connect to.
@@ -234,14 +276,22 @@ object SparkEnv extends Logging {
val shuffleMemoryManager = new ShuffleMemoryManager(conf)
- val blockTransferService = new NioBlockTransferService(conf, securityManager)
+ val blockTransferService =
+ conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match {
+ case "netty" =>
+ new NettyBlockTransferService(conf, securityManager, numUsableCores)
+ case "nio" =>
+ new NioBlockTransferService(conf, securityManager)
+ }
val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
"BlockManagerMaster",
new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver)
+ // NB: blockManager is not valid until initialize() is called later.
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
- serializer, conf, mapOutputTracker, shuffleManager, blockTransferService)
+ serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager,
+ numUsableCores)
val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
index 376e69cd997d5..40237596570de 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -26,6 +26,7 @@ import org.apache.hadoop.mapred._
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
+import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.rdd.HadoopRDD
/**
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
deleted file mode 100644
index 65003b6ac6a0a..0000000000000
--- a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
+++ /dev/null
@@ -1,146 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import java.io.IOException
-import javax.security.auth.callback.Callback
-import javax.security.auth.callback.CallbackHandler
-import javax.security.auth.callback.NameCallback
-import javax.security.auth.callback.PasswordCallback
-import javax.security.auth.callback.UnsupportedCallbackException
-import javax.security.sasl.RealmCallback
-import javax.security.sasl.RealmChoiceCallback
-import javax.security.sasl.Sasl
-import javax.security.sasl.SaslClient
-import javax.security.sasl.SaslException
-
-import scala.collection.JavaConversions.mapAsJavaMap
-
-/**
- * Implements SASL Client logic for Spark
- */
-private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logging {
-
- /**
- * Used to respond to server's counterpart, SaslServer with SASL tokens
- * represented as byte arrays.
- *
- * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
- * configurable in the future.
- */
- private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST),
- null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
- new SparkSaslClientCallbackHandler(securityMgr))
-
- /**
- * Used to initiate SASL handshake with server.
- * @return response to challenge if needed
- */
- def firstToken(): Array[Byte] = {
- synchronized {
- val saslToken: Array[Byte] =
- if (saslClient != null && saslClient.hasInitialResponse()) {
- logDebug("has initial response")
- saslClient.evaluateChallenge(new Array[Byte](0))
- } else {
- new Array[Byte](0)
- }
- saslToken
- }
- }
-
- /**
- * Determines whether the authentication exchange has completed.
- * @return true is complete, otherwise false
- */
- def isComplete(): Boolean = {
- synchronized {
- if (saslClient != null) saslClient.isComplete() else false
- }
- }
-
- /**
- * Respond to server's SASL token.
- * @param saslTokenMessage contains server's SASL token
- * @return client's response SASL token
- */
- def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = {
- synchronized {
- if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0)
- }
- }
-
- /**
- * Disposes of any system resources or security-sensitive information the
- * SaslClient might be using.
- */
- def dispose() {
- synchronized {
- if (saslClient != null) {
- try {
- saslClient.dispose()
- } catch {
- case e: SaslException => // ignored
- } finally {
- saslClient = null
- }
- }
- }
- }
-
- /**
- * Implementation of javax.security.auth.callback.CallbackHandler
- * that works with share secrets.
- */
- private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends
- CallbackHandler {
-
- private val userName: String =
- SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes("utf-8"))
- private val secretKey = securityMgr.getSecretKey()
- private val userPassword: Array[Char] = SparkSaslServer.encodePassword(
- if (secretKey != null) secretKey.getBytes("utf-8") else "".getBytes("utf-8"))
-
- /**
- * Implementation used to respond to SASL request from the server.
- *
- * @param callbacks objects that indicate what credential information the
- * server's SaslServer requires from the client.
- */
- override def handle(callbacks: Array[Callback]) {
- logDebug("in the sasl client callback handler")
- callbacks foreach {
- case nc: NameCallback => {
- logDebug("handle: SASL client callback: setting username: " + userName)
- nc.setName(userName)
- }
- case pc: PasswordCallback => {
- logDebug("handle: SASL client callback: setting userPassword")
- pc.setPassword(userPassword)
- }
- case rc: RealmCallback => {
- logDebug("handle: SASL client callback: setting realm: " + rc.getDefaultText())
- rc.setText(rc.getDefaultText())
- }
- case cb: RealmChoiceCallback => {}
- case cb: Callback => throw
- new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback")
- }
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
deleted file mode 100644
index f6b0a9132aca4..0000000000000
--- a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
+++ /dev/null
@@ -1,174 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import javax.security.auth.callback.Callback
-import javax.security.auth.callback.CallbackHandler
-import javax.security.auth.callback.NameCallback
-import javax.security.auth.callback.PasswordCallback
-import javax.security.auth.callback.UnsupportedCallbackException
-import javax.security.sasl.AuthorizeCallback
-import javax.security.sasl.RealmCallback
-import javax.security.sasl.Sasl
-import javax.security.sasl.SaslException
-import javax.security.sasl.SaslServer
-import scala.collection.JavaConversions.mapAsJavaMap
-import org.apache.commons.net.util.Base64
-
-/**
- * Encapsulates SASL server logic
- */
-private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Logging {
-
- /**
- * Actual SASL work done by this object from javax.security.sasl.
- */
- private var saslServer: SaslServer = Sasl.createSaslServer(SparkSaslServer.DIGEST, null,
- SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
- new SparkSaslDigestCallbackHandler(securityMgr))
-
- /**
- * Determines whether the authentication exchange has completed.
- * @return true is complete, otherwise false
- */
- def isComplete(): Boolean = {
- synchronized {
- if (saslServer != null) saslServer.isComplete() else false
- }
- }
-
- /**
- * Used to respond to server SASL tokens.
- * @param token Server's SASL token
- * @return response to send back to the server.
- */
- def response(token: Array[Byte]): Array[Byte] = {
- synchronized {
- if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0)
- }
- }
-
- /**
- * Disposes of any system resources or security-sensitive information the
- * SaslServer might be using.
- */
- def dispose() {
- synchronized {
- if (saslServer != null) {
- try {
- saslServer.dispose()
- } catch {
- case e: SaslException => // ignore
- } finally {
- saslServer = null
- }
- }
- }
- }
-
- /**
- * Implementation of javax.security.auth.callback.CallbackHandler
- * for SASL DIGEST-MD5 mechanism
- */
- private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager)
- extends CallbackHandler {
-
- private val userName: String =
- SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes("utf-8"))
-
- override def handle(callbacks: Array[Callback]) {
- logDebug("In the sasl server callback handler")
- callbacks foreach {
- case nc: NameCallback => {
- logDebug("handle: SASL server callback: setting username")
- nc.setName(userName)
- }
- case pc: PasswordCallback => {
- logDebug("handle: SASL server callback: setting userPassword")
- val password: Array[Char] =
- SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes("utf-8"))
- pc.setPassword(password)
- }
- case rc: RealmCallback => {
- logDebug("handle: SASL server callback: setting realm: " + rc.getDefaultText())
- rc.setText(rc.getDefaultText())
- }
- case ac: AuthorizeCallback => {
- val authid = ac.getAuthenticationID()
- val authzid = ac.getAuthorizationID()
- if (authid.equals(authzid)) {
- logDebug("set auth to true")
- ac.setAuthorized(true)
- } else {
- logDebug("set auth to false")
- ac.setAuthorized(false)
- }
- if (ac.isAuthorized()) {
- logDebug("sasl server is authorized")
- ac.setAuthorizedID(authzid)
- }
- }
- case cb: Callback => throw
- new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback")
- }
- }
- }
-}
-
-private[spark] object SparkSaslServer {
-
- /**
- * This is passed as the server name when creating the sasl client/server.
- * This could be changed to be configurable in the future.
- */
- val SASL_DEFAULT_REALM = "default"
-
- /**
- * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
- * configurable in the future.
- */
- val DIGEST = "DIGEST-MD5"
-
- /**
- * The quality of protection is just "auth". This means that we are doing
- * authentication only, we are not supporting integrity or privacy protection of the
- * communication channel after authentication. This could be changed to be configurable
- * in the future.
- */
- val SASL_PROPS = Map(Sasl.QOP -> "auth", Sasl.SERVER_AUTH ->"true")
-
- /**
- * Encode a byte[] identifier as a Base64-encoded string.
- *
- * @param identifier identifier to encode
- * @return Base64-encoded string
- */
- def encodeIdentifier(identifier: Array[Byte]): String = {
- new String(Base64.encodeBase64(identifier), "utf-8")
- }
-
- /**
- * Encode a password as a base64-encoded char[] array.
- * @param password as a byte array.
- * @return password as a char array.
- */
- def encodePassword(password: Array[Byte]): Array[Char] = {
- new String(Base64.encodeBase64(password), "utf-8").toCharArray()
- }
-}
-
diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala
new file mode 100644
index 0000000000000..edbdda8a0bcb6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * Low-level status reporting APIs for monitoring job and stage progress.
+ *
+ * These APIs intentionally provide very weak consistency semantics; consumers of these APIs should
+ * be prepared to handle empty / missing information. For example, a job's stage ids may be known
+ * but the status API may not have any information about the details of those stages, so
+ * `getStageInfo` could potentially return `None` for a valid stage id.
+ *
+ * To limit memory usage, these APIs only provide information on recent jobs / stages. These APIs
+ * will provide information for the last `spark.ui.retainedStages` stages and
+ * `spark.ui.retainedJobs` jobs.
+ *
+ * NOTE: this class's constructor should be considered private and may be subject to change.
+ */
+class SparkStatusTracker private[spark] (sc: SparkContext) {
+
+ private val jobProgressListener = sc.jobProgressListener
+
+ /**
+ * Return a list of all known jobs in a particular job group. If `jobGroup` is `null`, then
+ * returns all known jobs that are not associated with a job group.
+ *
+ * The returned list may contain running, failed, and completed jobs, and may vary across
+ * invocations of this method. This method does not guarantee the order of the elements in
+ * its result.
+ */
+ def getJobIdsForGroup(jobGroup: String): Array[Int] = {
+ jobProgressListener.synchronized {
+ val jobData = jobProgressListener.jobIdToData.valuesIterator
+ jobData.filter(_.jobGroup.orNull == jobGroup).map(_.jobId).toArray
+ }
+ }
+
+ /**
+ * Returns an array containing the ids of all active stages.
+ *
+ * This method does not guarantee the order of the elements in its result.
+ */
+ def getActiveStageIds(): Array[Int] = {
+ jobProgressListener.synchronized {
+ jobProgressListener.activeStages.values.map(_.stageId).toArray
+ }
+ }
+
+ /**
+ * Returns an array containing the ids of all active jobs.
+ *
+ * This method does not guarantee the order of the elements in its result.
+ */
+ def getActiveJobIds(): Array[Int] = {
+ jobProgressListener.synchronized {
+ jobProgressListener.activeJobs.values.map(_.jobId).toArray
+ }
+ }
+
+ /**
+ * Returns job information, or `None` if the job info could not be found or was garbage collected.
+ */
+ def getJobInfo(jobId: Int): Option[SparkJobInfo] = {
+ jobProgressListener.synchronized {
+ jobProgressListener.jobIdToData.get(jobId).map { data =>
+ new SparkJobInfoImpl(jobId, data.stageIds.toArray, data.status)
+ }
+ }
+ }
+
+ /**
+ * Returns stage information, or `None` if the stage info could not be found or was
+ * garbage collected.
+ */
+ def getStageInfo(stageId: Int): Option[SparkStageInfo] = {
+ jobProgressListener.synchronized {
+ for (
+ info <- jobProgressListener.stageIdToInfo.get(stageId);
+ data <- jobProgressListener.stageIdToData.get((stageId, info.attemptId))
+ ) yield {
+ new SparkStageInfoImpl(
+ stageId,
+ info.attemptId,
+ info.submissionTime.getOrElse(0),
+ info.name,
+ info.numTasks,
+ data.numActiveTasks,
+ data.numCompleteTasks,
+ data.numFailedTasks)
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala
new file mode 100644
index 0000000000000..e5c7c8d0db578
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+private class SparkJobInfoImpl (
+ val jobId: Int,
+ val stageIds: Array[Int],
+ val status: JobExecutionStatus)
+ extends SparkJobInfo
+
+private class SparkStageInfoImpl(
+ val stageId: Int,
+ val currentAttemptId: Int,
+ val submissionTime: Long,
+ val name: String,
+ val numTasks: Int,
+ val numActiveTasks: Int,
+ val numCompletedTasks: Int,
+ val numFailedTasks: Int)
+ extends SparkStageInfo
diff --git a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala b/core/src/main/scala/org/apache/spark/TaskContextHelper.scala
new file mode 100644
index 0000000000000..4636c4600a01a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/TaskContextHelper.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * This class exists to restrict the visibility of TaskContext setters.
+ */
+private [spark] object TaskContextHelper {
+
+ def setTaskContext(tc: TaskContext): Unit = TaskContext.setTaskContext(tc)
+
+ def unset(): Unit = TaskContext.unset()
+
+}
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
new file mode 100644
index 0000000000000..afd2b85d33a77
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
+
+import scala.collection.mutable.ArrayBuffer
+
+private[spark] class TaskContextImpl(val stageId: Int,
+ val partitionId: Int,
+ val attemptId: Long,
+ val runningLocally: Boolean = false,
+ val taskMetrics: TaskMetrics = TaskMetrics.empty)
+ extends TaskContext
+ with Logging {
+
+ // List of callback functions to execute when the task completes.
+ @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]
+
+ // Whether the corresponding task has been killed.
+ @volatile private var interrupted: Boolean = false
+
+ // Whether the task has completed.
+ @volatile private var completed: Boolean = false
+
+ override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
+ onCompleteCallbacks += listener
+ this
+ }
+
+ override def addTaskCompletionListener(f: TaskContext => Unit): this.type = {
+ onCompleteCallbacks += new TaskCompletionListener {
+ override def onTaskCompletion(context: TaskContext): Unit = f(context)
+ }
+ this
+ }
+
+ @deprecated("use addTaskCompletionListener", "1.1.0")
+ override def addOnCompleteCallback(f: () => Unit) {
+ onCompleteCallbacks += new TaskCompletionListener {
+ override def onTaskCompletion(context: TaskContext): Unit = f()
+ }
+ }
+
+ /** Marks the task as completed and triggers the listeners. */
+ private[spark] def markTaskCompleted(): Unit = {
+ completed = true
+ val errorMsgs = new ArrayBuffer[String](2)
+ // Process complete callbacks in the reverse order of registration
+ onCompleteCallbacks.reverse.foreach { listener =>
+ try {
+ listener.onTaskCompletion(this)
+ } catch {
+ case e: Throwable =>
+ errorMsgs += e.getMessage
+ logError("Error in TaskCompletionListener", e)
+ }
+ }
+ if (errorMsgs.nonEmpty) {
+ throw new TaskCompletionListenerException(errorMsgs)
+ }
+ }
+
+ /** Marks the task for interruption, i.e. cancellation. */
+ private[spark] def markInterrupted(): Unit = {
+ interrupted = true
+ }
+
+ override def isCompleted: Boolean = completed
+
+ override def isRunningLocally: Boolean = runningLocally
+
+ override def isInterrupted: Boolean = interrupted
+}
+
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 8f0c5e78416c2..af5fd8e0ac00c 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -69,11 +69,13 @@ case class FetchFailed(
bmAddress: BlockManagerId, // Note that bmAddress can be null
shuffleId: Int,
mapId: Int,
- reduceId: Int)
+ reduceId: Int,
+ message: String)
extends TaskFailedReason {
override def toErrorString: String = {
val bmAddressString = if (bmAddress == null) "null" else bmAddress.toString
- s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId)"
+ s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId, " +
+ s"message=\n$message\n)"
}
}
@@ -81,15 +83,48 @@ case class FetchFailed(
* :: DeveloperApi ::
* Task failed due to a runtime exception. This is the most common failure case and also captures
* user program exceptions.
+ *
+ * `stackTrace` contains the stack trace of the exception itself. It still exists for backward
+ * compatibility. It's better to use `this(e: Throwable, metrics: Option[TaskMetrics])` to
+ * create `ExceptionFailure` as it will handle the backward compatibility properly.
+ *
+ * `fullStackTrace` is a better representation of the stack trace because it contains the whole
+ * stack trace including the exception and its causes
*/
@DeveloperApi
case class ExceptionFailure(
className: String,
description: String,
stackTrace: Array[StackTraceElement],
+ fullStackTrace: String,
metrics: Option[TaskMetrics])
extends TaskFailedReason {
- override def toErrorString: String = Utils.exceptionString(className, description, stackTrace)
+
+ private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) {
+ this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics)
+ }
+
+ override def toErrorString: String =
+ if (fullStackTrace == null) {
+ // fullStackTrace is added in 1.2.0
+ // If fullStackTrace is null, use the old error string for backward compatibility
+ exceptionString(className, description, stackTrace)
+ } else {
+ fullStackTrace
+ }
+
+ /**
+ * Return a nice string representation of the exception, including the stack trace.
+ * Note: It does not include the exception's causes, and is only used for backward compatibility.
+ */
+ private def exceptionString(
+ className: String,
+ description: String,
+ stackTrace: Array[StackTraceElement]): String = {
+ val desc = if (description == null) "" else description
+ val st = if (stackTrace == null) "" else stackTrace.map(" " + _).mkString("\n")
+ s"$className: $desc\n$st"
+ }
}
/**
@@ -117,8 +152,8 @@ case object TaskKilled extends TaskFailedReason {
* the task crashed the JVM.
*/
@DeveloperApi
-case object ExecutorLostFailure extends TaskFailedReason {
- override def toErrorString: String = "ExecutorLostFailure (executor lost)"
+case class ExecutorLostFailure(execId: String) extends TaskFailedReason {
+ override def toErrorString: String = s"ExecutorLostFailure (executor ${execId} lost)"
}
/**
diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index 8ca731038e528..34078142f5385 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -23,8 +23,10 @@ import java.util.jar.{JarEntry, JarOutputStream}
import scala.collection.JavaConversions._
+import com.google.common.io.{ByteStreams, Files}
import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}
-import com.google.common.io.Files
+
+import org.apache.spark.util.Utils
/**
* Utilities for tests. Included in main codebase since it's used by multiple
@@ -42,8 +44,7 @@ private[spark] object TestUtils {
* in order to avoid interference between tests.
*/
def createJarWithClasses(classNames: Seq[String], value: String = ""): URL = {
- val tempDir = Files.createTempDir()
- tempDir.deleteOnExit()
+ val tempDir = Utils.createTempDir()
val files = for (name <- classNames) yield createCompiledClass(name, tempDir, value)
val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis()))
createJar(files, jarFile)
@@ -63,12 +64,7 @@ private[spark] object TestUtils {
jarStream.putNextEntry(jarEntry)
val in = new FileInputStream(file)
- val buffer = new Array[Byte](10240)
- var nRead = 0
- while (nRead <= 0) {
- nRead = in.read(buffer, 0, buffer.length)
- jarStream.write(buffer, 0, nRead)
- }
+ ByteStreams.copy(in, jarStream)
in.close()
}
jarStream.close()
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index a6123bd108c11..8e8f7f6c4fda2 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -114,7 +114,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja
* Return an RDD with the elements from `this` that are not in `other`.
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
- * RDD will be <= us.
+ * RDD will be <= us.
*/
def subtract(other: JavaDoubleRDD): JavaDoubleRDD =
fromRDD(srdd.subtract(other))
@@ -233,11 +233,11 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja
* to the left except for the last which is closed
* e.g. for the array
* [1,10,20,50] the buckets are [1,10) [10,20) [20,50]
- * e.g 1<=x<10 , 10<=x<20, 20<=x<50
+ * e.g 1<=x<10 , 10<=x<20, 20<=x<50
* And on the input of 1 and 50 we would have a histogram of 1,0,0
*
* Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched
- * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets
+ * from an O(log n) insertion to O(1) per element. (where n = # buckets) if you set evenBuckets
* to true.
* buckets must be sorted and not contain any duplicates.
* buckets array must be at least two elements
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 0846225e4f992..7af3538262fd6 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -32,12 +32,13 @@ import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
import org.apache.spark.{HashPartitioner, Partitioner}
import org.apache.spark.Partitioner._
-import org.apache.spark.SparkContext.rddToPairRDDFunctions
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap
import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, PairFunction}
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.rdd.{OrderedRDDFunctions, RDD}
+import org.apache.spark.rdd.RDD.rddToPairRDDFunctions
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -265,10 +266,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* before sending results to a reducer, similarly to a "combiner" in MapReduce.
*/
def reduceByKeyLocally(func: JFunction2[V, V, V]): java.util.Map[K, V] =
- mapAsJavaMap(rdd.reduceByKeyLocally(func))
+ mapAsSerializableJavaMap(rdd.reduceByKeyLocally(func))
/** Count the number of elements for each key, and return the result to the master as a Map. */
- def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey())
+ def countByKey(): java.util.Map[K, Long] = mapAsSerializableJavaMap(rdd.countByKey())
/**
* :: Experimental ::
@@ -277,7 +278,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
*/
@Experimental
def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] =
- rdd.countByKeyApprox(timeout).map(mapAsJavaMap)
+ rdd.countByKeyApprox(timeout).map(mapAsSerializableJavaMap)
/**
* :: Experimental ::
@@ -287,7 +288,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
@Experimental
def countByKeyApprox(timeout: Long, confidence: Double = 0.95)
: PartialResult[java.util.Map[K, BoundedDouble]] =
- rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap)
+ rdd.countByKeyApprox(timeout, confidence).map(mapAsSerializableJavaMap)
/**
* Aggregate the values of each key, using given combine functions and a neutral "zero value".
@@ -391,7 +392,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* Return an RDD with the elements from `this` that are not in `other`.
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
- * RDD will be <= us.
+ * RDD will be <= us.
*/
def subtract(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] =
fromRDD(rdd.subtract(other))
@@ -412,7 +413,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* Return an RDD with the pairs from `this` whose keys are not in `other`.
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
- * RDD will be <= us.
+ * RDD will be <= us.
*/
def subtractByKey[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, V] = {
implicit val ctag: ClassTag[W] = fakeClassTag
@@ -614,7 +615,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
/**
* Return the key-value pairs in this RDD to the master as a Map.
*/
- def collectAsMap(): java.util.Map[K, V] = mapAsJavaMap(rdd.collectAsMap())
+ def collectAsMap(): java.util.Map[K, V] = mapAsSerializableJavaMap(rdd.collectAsMap())
+
/**
* Pass each value in the key-value pair RDD through a map function without changing the keys;
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 545bc0e9e99ed..5a8e5bb1f721a 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -21,15 +21,18 @@ import java.util.{Comparator, List => JList, Iterator => JIterator}
import java.lang.{Iterable => JIterable, Long => JLong}
import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
import scala.reflect.ClassTag
import com.google.common.base.Optional
import org.apache.hadoop.io.compress.CompressionCodec
-import org.apache.spark.{FutureAction, Partition, SparkContext, TaskContext}
+import org.apache.spark._
+import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaPairRDD._
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap
import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _}
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.rdd.RDD
@@ -293,8 +296,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Applies a function f to all elements of this RDD.
*/
def foreach(f: VoidFunction[T]) {
- val cleanF = rdd.context.clean((x: T) => f.call(x))
- rdd.foreach(cleanF)
+ rdd.foreach(x => f.call(x))
}
/**
@@ -390,7 +392,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* combine step happens locally on the master, equivalent to running a single reduce task.
*/
def countByValue(): java.util.Map[T, java.lang.Long] =
- mapAsJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2)))))
+ mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2)))))
/**
* (Experimental) Approximate version of countByValue().
@@ -399,13 +401,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
timeout: Long,
confidence: Double
): PartialResult[java.util.Map[T, BoundedDouble]] =
- rdd.countByValueApprox(timeout, confidence).map(mapAsJavaMap)
+ rdd.countByValueApprox(timeout, confidence).map(mapAsSerializableJavaMap)
/**
* (Experimental) Approximate version of countByValue().
*/
def countByValueApprox(timeout: Long): PartialResult[java.util.Map[T, BoundedDouble]] =
- rdd.countByValueApprox(timeout).map(mapAsJavaMap)
+ rdd.countByValueApprox(timeout).map(mapAsSerializableJavaMap)
/**
* Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
@@ -491,9 +493,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
/**
- * Returns the top K elements from this RDD as defined by
+ * Returns the top k (largest) elements from this RDD as defined by
* the specified Comparator[T].
- * @param num the number of top elements to return
+ * @param num k, the number of top elements to return
* @param comp the comparator that defines the order
* @return an array of top elements
*/
@@ -505,9 +507,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
/**
- * Returns the top K elements from this RDD using the
+ * Returns the top k (largest) elements from this RDD using the
* natural ordering for T.
- * @param num the number of top elements to return
+ * @param num k, the number of top elements to return
* @return an array of top elements
*/
def top(num: Int): JList[T] = {
@@ -516,9 +518,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
/**
- * Returns the first K elements from this RDD as defined by
+ * Returns the first k (smallest) elements from this RDD as defined by
* the specified Comparator[T] and maintains the order.
- * @param num the number of top elements to return
+ * @param num k, the number of elements to return
* @param comp the comparator that defines the order
* @return an array of top elements
*/
@@ -550,9 +552,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
/**
- * Returns the first K elements from this RDD using the
+ * Returns the first k (smallest) elements from this RDD using the
* natural ordering for T while maintain the order.
- * @param num the number of top elements to return
+ * @param num k, the number of top elements to return
* @return an array of top elements
*/
def takeOrdered(num: Int): JList[T] = {
@@ -575,16 +577,44 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def name(): String = rdd.name
/**
- * :: Experimental ::
- * The asynchronous version of the foreach action.
- *
- * @param f the function to apply to all the elements of the RDD
- * @return a FutureAction for the action
+ * The asynchronous version of `count`, which returns a
+ * future for counting the number of elements in this RDD.
*/
- @Experimental
- def foreachAsync(f: VoidFunction[T]): FutureAction[Unit] = {
- import org.apache.spark.SparkContext._
- rdd.foreachAsync(x => f.call(x))
+ def countAsync(): JavaFutureAction[JLong] = {
+ new JavaFutureActionWrapper[Long, JLong](rdd.countAsync(), JLong.valueOf)
+ }
+
+ /**
+ * The asynchronous version of `collect`, which returns a future for
+ * retrieving an array containing all of the elements in this RDD.
+ */
+ def collectAsync(): JavaFutureAction[JList[T]] = {
+ new JavaFutureActionWrapper(rdd.collectAsync(), (x: Seq[T]) => x.asJava)
+ }
+
+ /**
+ * The asynchronous version of the `take` action, which returns a
+ * future for retrieving the first `num` elements of this RDD.
+ */
+ def takeAsync(num: Int): JavaFutureAction[JList[T]] = {
+ new JavaFutureActionWrapper(rdd.takeAsync(num), (x: Seq[T]) => x.asJava)
}
+ /**
+ * The asynchronous version of the `foreach` action, which
+ * applies a function f to all the elements of this RDD.
+ */
+ def foreachAsync(f: VoidFunction[T]): JavaFutureAction[Void] = {
+ new JavaFutureActionWrapper[Unit, Void](rdd.foreachAsync(x => f.call(x)),
+ { x => null.asInstanceOf[Void] })
+ }
+
+ /**
+ * The asynchronous version of the `foreachPartition` action, which
+ * applies a function f to each partition of this RDD.
+ */
+ def foreachPartitionAsync(f: VoidFunction[java.util.Iterator[T]]): JavaFutureAction[Void] = {
+ new JavaFutureActionWrapper[Unit, Void](rdd.foreachPartitionAsync(x => f.call(x)),
+ { x => null.asInstanceOf[Void] })
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index 791d853a015a1..97f5c9f257e09 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -28,11 +28,13 @@ import scala.reflect.ClassTag
import com.google.common.base.Optional
import org.apache.hadoop.conf.Configuration
+import org.apache.spark.input.PortableDataStream
import org.apache.hadoop.mapred.{InputFormat, JobConf}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.spark._
-import org.apache.spark.SparkContext.{DoubleAccumulatorParam, IntAccumulatorParam}
+import org.apache.spark.AccumulatorParam._
+import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD}
@@ -40,6 +42,9 @@ import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD}
/**
* A Java-friendly version of [[org.apache.spark.SparkContext]] that returns
* [[org.apache.spark.api.java.JavaRDD]]s and works with Java collections instead of Scala ones.
+ *
+ * Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before
+ * creating a new one. This limitation may eventually be removed; see SPARK-2243 for more details.
*/
class JavaSparkContext(val sc: SparkContext)
extends JavaSparkContextVarargsWorkaround with Closeable {
@@ -103,6 +108,8 @@ class JavaSparkContext(val sc: SparkContext)
private[spark] val env = sc.env
+ def statusTracker = new JavaSparkStatusTracker(sc)
+
def isLocal: java.lang.Boolean = sc.isLocal
def sparkUser: String = sc.sparkUser
@@ -183,6 +190,8 @@ class JavaSparkContext(val sc: SparkContext)
def textFile(path: String, minPartitions: Int): JavaRDD[String] =
sc.textFile(path, minPartitions)
+
+
/**
* Read a directory of text files from HDFS, a local file system (available on all nodes), or any
* Hadoop-supported file system URI. Each file is read as a single record and returned in a
@@ -196,7 +205,10 @@ class JavaSparkContext(val sc: SparkContext)
* hdfs://a-hdfs-path/part-nnnnn
* }}}
*
- * Do `JavaPairRDD rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")`,
+ * Do
+ * {{{
+ * JavaPairRDD rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")
+ * }}}
*
*
then `rdd` contains
* {{{
@@ -223,6 +235,84 @@ class JavaSparkContext(val sc: SparkContext)
def wholeTextFiles(path: String): JavaPairRDD[String, String] =
new JavaPairRDD(sc.wholeTextFiles(path))
+ /**
+ * Read a directory of binary files from HDFS, a local file system (available on all nodes),
+ * or any Hadoop-supported file system URI as a byte array. Each file is read as a single
+ * record and returned in a key-value pair, where the key is the path of each file,
+ * the value is the content of each file.
+ *
+ * For example, if you have the following files:
+ * {{{
+ * hdfs://a-hdfs-path/part-00000
+ * hdfs://a-hdfs-path/part-00001
+ * ...
+ * hdfs://a-hdfs-path/part-nnnnn
+ * }}}
+ *
+ * Do
+ * `JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`,
+ *
+ * then `rdd` contains
+ * {{{
+ * (a-hdfs-path/part-00000, its content)
+ * (a-hdfs-path/part-00001, its content)
+ * ...
+ * (a-hdfs-path/part-nnnnn, its content)
+ * }}}
+ *
+ * @note Small files are preferred; very large files but may cause bad performance.
+ *
+ * @param minPartitions A suggestion value of the minimal splitting number for input data.
+ */
+ def binaryFiles(path: String, minPartitions: Int): JavaPairRDD[String, PortableDataStream] =
+ new JavaPairRDD(sc.binaryFiles(path, minPartitions))
+
+ /**
+ * :: Experimental ::
+ *
+ * Read a directory of binary files from HDFS, a local file system (available on all nodes),
+ * or any Hadoop-supported file system URI as a byte array. Each file is read as a single
+ * record and returned in a key-value pair, where the key is the path of each file,
+ * the value is the content of each file.
+ *
+ * For example, if you have the following files:
+ * {{{
+ * hdfs://a-hdfs-path/part-00000
+ * hdfs://a-hdfs-path/part-00001
+ * ...
+ * hdfs://a-hdfs-path/part-nnnnn
+ * }}}
+ *
+ * Do
+ * `JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`,
+ *
+ * then `rdd` contains
+ * {{{
+ * (a-hdfs-path/part-00000, its content)
+ * (a-hdfs-path/part-00001, its content)
+ * ...
+ * (a-hdfs-path/part-nnnnn, its content)
+ * }}}
+ *
+ * @note Small files are preferred; very large files but may cause bad performance.
+ */
+ @Experimental
+ def binaryFiles(path: String): JavaPairRDD[String, PortableDataStream] =
+ new JavaPairRDD(sc.binaryFiles(path, defaultMinPartitions))
+
+ /**
+ * :: Experimental ::
+ *
+ * Load data from a flat binary file, assuming the length of each record is constant.
+ *
+ * @param path Directory to the input data files
+ * @return An RDD of data with values, represented as byte arrays
+ */
+ @Experimental
+ def binaryRecords(path: String, recordLength: Int): JavaRDD[Array[Byte]] = {
+ new JavaRDD(sc.binaryRecords(path, recordLength))
+ }
+
/** Get an RDD for a Hadoop SequenceFile with given key and value types.
*
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala
new file mode 100644
index 0000000000000..3300cad9efbab
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java
+
+import org.apache.spark.{SparkStageInfo, SparkJobInfo, SparkContext}
+
+/**
+ * Low-level status reporting APIs for monitoring job and stage progress.
+ *
+ * These APIs intentionally provide very weak consistency semantics; consumers of these APIs should
+ * be prepared to handle empty / missing information. For example, a job's stage ids may be known
+ * but the status API may not have any information about the details of those stages, so
+ * `getStageInfo` could potentially return `null` for a valid stage id.
+ *
+ * To limit memory usage, these APIs only provide information on recent jobs / stages. These APIs
+ * will provide information for the last `spark.ui.retainedStages` stages and
+ * `spark.ui.retainedJobs` jobs.
+ *
+ * NOTE: this class's constructor should be considered private and may be subject to change.
+ */
+class JavaSparkStatusTracker private[spark] (sc: SparkContext) {
+
+ /**
+ * Return a list of all known jobs in a particular job group. If `jobGroup` is `null`, then
+ * returns all known jobs that are not associated with a job group.
+ *
+ * The returned list may contain running, failed, and completed jobs, and may vary across
+ * invocations of this method. This method does not guarantee the order of the elements in
+ * its result.
+ */
+ def getJobIdsForGroup(jobGroup: String): Array[Int] = sc.statusTracker.getJobIdsForGroup(jobGroup)
+
+ /**
+ * Returns an array containing the ids of all active stages.
+ *
+ * This method does not guarantee the order of the elements in its result.
+ */
+ def getActiveStageIds(): Array[Int] = sc.statusTracker.getActiveStageIds()
+
+ /**
+ * Returns an array containing the ids of all active jobs.
+ *
+ * This method does not guarantee the order of the elements in its result.
+ */
+ def getActiveJobIds(): Array[Int] = sc.statusTracker.getActiveJobIds()
+
+ /**
+ * Returns job information, or `null` if the job info could not be found or was garbage collected.
+ */
+ def getJobInfo(jobId: Int): SparkJobInfo = sc.statusTracker.getJobInfo(jobId).orNull
+
+ /**
+ * Returns stage information, or `null` if the stage info could not be found or was
+ * garbage collected.
+ */
+ def getStageInfo(stageId: Int): SparkStageInfo = sc.statusTracker.getStageInfo(stageId).orNull
+}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
index 22810cb1c662d..b52d0a5028e84 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
@@ -19,10 +19,20 @@ package org.apache.spark.api.java
import com.google.common.base.Optional
+import scala.collection.convert.Wrappers.MapWrapper
+
private[spark] object JavaUtils {
def optionToOptional[T](option: Option[T]): Optional[T] =
option match {
case Some(value) => Optional.of(value)
case None => Optional.absent()
}
+
+ // Workaround for SPARK-3926 / SI-8911
+ def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]) =
+ new SerializableMapWrapper(underlying)
+
+ class SerializableMapWrapper[A, B](underlying: collection.Map[A, B])
+ extends MapWrapper(underlying) with java.io.Serializable
+
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
index 49dc95f349eac..5ba66178e2b78 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
@@ -61,8 +61,7 @@ private[python] object Converter extends Logging {
* Other objects are passed through without conversion.
*/
private[python] class WritableToJavaConverter(
- conf: Broadcast[SerializableWritable[Configuration]],
- batchSize: Int) extends Converter[Any, Any] {
+ conf: Broadcast[SerializableWritable[Configuration]]) extends Converter[Any, Any] {
/**
* Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or
@@ -94,8 +93,7 @@ private[python] class WritableToJavaConverter(
map.put(convertWritable(k), convertWritable(v))
}
map
- case w: Writable =>
- if (batchSize > 1) WritableUtils.clone(w, conf.value.value) else w
+ case w: Writable => WritableUtils.clone(w, conf.value.value)
case other => other
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 924141475383d..e0bc00e1eb249 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -19,16 +19,15 @@ package org.apache.spark.api.python
import java.io._
import java.net._
-import java.nio.charset.Charset
-import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
+import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections}
+
+import org.apache.spark.input.PortableDataStream
import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.language.existentials
-import scala.reflect.ClassTag
-import scala.util.{Try, Success, Failure}
-import net.razorvine.pickle.{Pickler, Unpickler}
+import com.google.common.base.Charsets.UTF_8
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodec
@@ -42,22 +41,22 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
private[spark] class PythonRDD(
- parent: RDD[_],
+ @transient parent: RDD[_],
command: Array[Byte],
envVars: JMap[String, String],
pythonIncludes: JList[String],
preservePartitoning: Boolean,
pythonExec: String,
- broadcastVars: JList[Broadcast[Array[Byte]]],
+ broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: Accumulator[JList[Array[Byte]]])
extends RDD[Array[Byte]](parent) {
val bufferSize = conf.getInt("spark.buffer.size", 65536)
val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)
- override def getPartitions = parent.partitions
+ override def getPartitions = firstParent.partitions
- override val partitioner = if (preservePartitoning) parent.partitioner else None
+ override val partitioner = if (preservePartitoning) firstParent.partitioner else None
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val startTime = System.currentTimeMillis
@@ -76,6 +75,7 @@ private[spark] class PythonRDD(
var complete_cleanly = false
context.addTaskCompletionListener { context =>
writerThread.shutdownOnTaskCompletion()
+ writerThread.join()
if (reuse_worker && complete_cleanly) {
env.releasePythonWorker(pythonExec, envVars.toMap, worker)
} else {
@@ -134,7 +134,7 @@ private[spark] class PythonRDD(
val exLength = stream.readInt()
val obj = new Array[Byte](exLength)
stream.readFully(obj)
- throw new PythonException(new String(obj, "utf-8"),
+ throw new PythonException(new String(obj, UTF_8),
writerThread.exception.getOrElse(null))
case SpecialLengths.END_OF_DATA_SECTION =>
// We've finished the data section of the output, but we can still
@@ -146,7 +146,9 @@ private[spark] class PythonRDD(
stream.readFully(update)
accumulator += Collections.singletonList(update)
}
- complete_cleanly = true
+ if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
+ complete_cleanly = true
+ }
null
}
} catch {
@@ -155,6 +157,10 @@ private[spark] class PythonRDD(
logDebug("Exception thrown after task interruption", e)
throw new TaskKilledException
+ case e: Exception if env.isStopped =>
+ logDebug("Exception thrown after context is stopped", e)
+ null // exit silently
+
case e: Exception if writerThread.exception.isDefined =>
logError("Python worker exited unexpectedly (crashed)", e)
logError("This may have been caused by a prior exception:", writerThread.exception.get)
@@ -196,7 +202,6 @@ private[spark] class PythonRDD(
override def run(): Unit = Utils.logUncaughtExceptions {
try {
- SparkEnv.set(env)
val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
val dataOut = new DataOutputStream(stream)
// Partition index
@@ -225,8 +230,7 @@ private[spark] class PythonRDD(
if (!oldBids.contains(broadcast.id)) {
// send new broadcast
dataOut.writeLong(broadcast.id)
- dataOut.writeInt(broadcast.value.length)
- dataOut.write(broadcast.value)
+ PythonRDD.writeUTF(broadcast.value.path, dataOut)
oldBids.add(broadcast.id)
}
}
@@ -235,8 +239,9 @@ private[spark] class PythonRDD(
dataOut.writeInt(command.length)
dataOut.write(command)
// Data values
- PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
+ PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+ dataOut.writeInt(SpecialLengths.END_OF_STREAM)
dataOut.flush()
} catch {
case e: Exception if context.isCompleted || context.isInterrupted =>
@@ -248,6 +253,11 @@ private[spark] class PythonRDD(
// will kill the whole executor (see org.apache.spark.executor.Executor).
_exception = e
worker.shutdownOutput()
+ } finally {
+ // Release memory used by this thread for shuffles
+ env.shuffleMemoryManager.releaseMemoryForThisThread()
+ // Release memory used by this thread for unrolling blocks
+ env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
}
}
}
@@ -303,10 +313,10 @@ private object SpecialLengths {
val END_OF_DATA_SECTION = -1
val PYTHON_EXCEPTION_THROWN = -2
val TIMING_DATA = -3
+ val END_OF_STREAM = -4
}
private[spark] object PythonRDD extends Logging {
- val UTF8 = Charset.forName("UTF-8")
// remember the broadcasts sent to each worker
private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
@@ -357,16 +367,8 @@ private[spark] object PythonRDD extends Logging {
}
}
- def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = {
- val file = new DataInputStream(new FileInputStream(filename))
- try {
- val length = file.readInt()
- val obj = new Array[Byte](length)
- file.readFully(obj)
- sc.broadcast(obj)
- } finally {
- file.close()
- }
+ def readBroadcastFromFile(sc: JavaSparkContext, path: String): Broadcast[PythonBroadcast] = {
+ sc.broadcast(new PythonBroadcast(path))
}
def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
@@ -386,22 +388,33 @@ private[spark] object PythonRDD extends Logging {
newIter.asInstanceOf[Iterator[String]].foreach { str =>
writeUTF(str, dataOut)
}
- case pair: Tuple2[_, _] =>
- pair._1 match {
- case bytePair: Array[Byte] =>
- newIter.asInstanceOf[Iterator[Tuple2[Array[Byte], Array[Byte]]]].foreach { pair =>
- dataOut.writeInt(pair._1.length)
- dataOut.write(pair._1)
- dataOut.writeInt(pair._2.length)
- dataOut.write(pair._2)
- }
- case stringPair: String =>
- newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair =>
- writeUTF(pair._1, dataOut)
- writeUTF(pair._2, dataOut)
- }
- case other =>
- throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass)
+ case stream: PortableDataStream =>
+ newIter.asInstanceOf[Iterator[PortableDataStream]].foreach { stream =>
+ val bytes = stream.toArray()
+ dataOut.writeInt(bytes.length)
+ dataOut.write(bytes)
+ }
+ case (key: String, stream: PortableDataStream) =>
+ newIter.asInstanceOf[Iterator[(String, PortableDataStream)]].foreach {
+ case (key, stream) =>
+ writeUTF(key, dataOut)
+ val bytes = stream.toArray()
+ dataOut.writeInt(bytes.length)
+ dataOut.write(bytes)
+ }
+ case (key: String, value: String) =>
+ newIter.asInstanceOf[Iterator[(String, String)]].foreach {
+ case (key, value) =>
+ writeUTF(key, dataOut)
+ writeUTF(value, dataOut)
+ }
+ case (key: Array[Byte], value: Array[Byte]) =>
+ newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach {
+ case (key, value) =>
+ dataOut.writeInt(key.length)
+ dataOut.write(key)
+ dataOut.writeInt(value.length)
+ dataOut.write(value)
}
case other =>
throw new SparkException("Unexpected element type " + first.getClass)
@@ -431,7 +444,7 @@ private[spark] object PythonRDD extends Logging {
val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration()))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -457,7 +470,7 @@ private[spark] object PythonRDD extends Logging {
Some(path), inputFormatClass, keyClass, valueClass, mergedConf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -483,7 +496,7 @@ private[spark] object PythonRDD extends Logging {
None, inputFormatClass, keyClass, valueClass, conf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -526,7 +539,7 @@ private[spark] object PythonRDD extends Logging {
Some(path), inputFormatClass, keyClass, valueClass, mergedConf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -552,7 +565,7 @@ private[spark] object PythonRDD extends Logging {
None, inputFormatClass, keyClass, valueClass, conf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -574,7 +587,7 @@ private[spark] object PythonRDD extends Logging {
}
def writeUTF(str: String, dataOut: DataOutputStream) {
- val bytes = str.getBytes(UTF8)
+ val bytes = str.getBytes(UTF_8)
dataOut.writeInt(bytes.length)
dataOut.write(bytes)
}
@@ -735,107 +748,11 @@ private[spark] object PythonRDD extends Logging {
converted.saveAsHadoopDataset(new JobConf(conf))
}
}
-
-
- /**
- * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
- */
- @deprecated("PySpark does not use it anymore", "1.1")
- def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
- pyRDD.rdd.mapPartitions { iter =>
- val unpickle = new Unpickler
- iter.flatMap { row =>
- unpickle.loads(row) match {
- // in case of objects are pickled in batch mode
- case objs: JArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
- // not in batch mode
- case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
- }
- }
- }
- }
-
- /**
- * Convert an RDD of serialized Python tuple to Array (no recursive conversions).
- * It is only used by pyspark.sql.
- */
- def pythonToJavaArray(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Array[_]] = {
-
- def toArray(obj: Any): Array[_] = {
- obj match {
- case objs: JArrayList[_] =>
- objs.toArray
- case obj if obj.getClass.isArray =>
- obj.asInstanceOf[Array[_]].toArray
- }
- }
-
- pyRDD.rdd.mapPartitions { iter =>
- val unpickle = new Unpickler
- iter.flatMap { row =>
- val obj = unpickle.loads(row)
- if (batched) {
- obj.asInstanceOf[JArrayList[_]].map(toArray)
- } else {
- Seq(toArray(obj))
- }
- }
- }.toJavaRDD()
- }
-
- private class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
- private val pickle = new Pickler()
- private var batch = 1
- private val buffer = new mutable.ArrayBuffer[Any]
-
- override def hasNext(): Boolean = iter.hasNext
-
- override def next(): Array[Byte] = {
- while (iter.hasNext && buffer.length < batch) {
- buffer += iter.next()
- }
- val bytes = pickle.dumps(buffer.toArray)
- val size = bytes.length
- // let 1M < size < 10M
- if (size < 1024 * 1024) {
- batch *= 2
- } else if (size > 1024 * 1024 * 10 && batch > 1) {
- batch /= 2
- }
- buffer.clear()
- bytes
- }
- }
-
- /**
- * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
- * PySpark.
- */
- def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
- jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
- }
-
- /**
- * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark.
- */
- def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
- pyRDD.rdd.mapPartitions { iter =>
- val unpickle = new Unpickler
- iter.flatMap { row =>
- val obj = unpickle.loads(row)
- if (batched) {
- obj.asInstanceOf[JArrayList[_]]
- } else {
- Seq(obj)
- }
- }
- }.toJavaRDD()
- }
}
private
class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
- override def call(arr: Array[Byte]) : String = new String(arr, PythonRDD.UTF8)
+ override def call(arr: Array[Byte]) : String = new String(arr, UTF_8)
}
/**
@@ -890,3 +807,49 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort:
}
}
}
+
+/**
+ * An Wrapper for Python Broadcast, which is written into disk by Python. It also will
+ * write the data into disk after deserialization, then Python can read it from disks.
+ */
+private[spark] class PythonBroadcast(@transient var path: String) extends Serializable {
+
+ /**
+ * Read data from disks, then copy it to `out`
+ */
+ private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
+ val in = new FileInputStream(new File(path))
+ try {
+ Utils.copyStream(in, out)
+ } finally {
+ in.close()
+ }
+ }
+
+ /**
+ * Write data into disk, using randomly generated name.
+ */
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
+ val dir = new File(Utils.getLocalDir(SparkEnv.get.conf))
+ val file = File.createTempFile("broadcast", "", dir)
+ path = file.getAbsolutePath
+ val out = new FileOutputStream(file)
+ try {
+ Utils.copyStream(in, out)
+ } finally {
+ out.close()
+ }
+ }
+
+ /**
+ * Delete the file once the object is GCed.
+ */
+ override def finalize() {
+ if (!path.isEmpty) {
+ val file = new File(path)
+ if (file.exists()) {
+ file.delete()
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 71bdf0fe1b917..e314408c067e9 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -108,10 +108,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
// Create and start the worker
- val pb = new ProcessBuilder(Seq(pythonExec, "-u", "-m", "pyspark.worker"))
+ val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.worker"))
val workerEnv = pb.environment()
workerEnv.putAll(envVars)
workerEnv.put("PYTHONPATH", pythonPath)
+ // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
+ workerEnv.put("PYTHONUNBUFFERED", "YES")
val worker = pb.start()
// Redirect worker stdout and stderr
@@ -149,10 +151,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
try {
// Create and start the daemon
- val pb = new ProcessBuilder(Seq(pythonExec, "-u", "-m", "pyspark.daemon"))
+ val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.daemon"))
val workerEnv = pb.environment()
workerEnv.putAll(envVars)
workerEnv.put("PYTHONPATH", pythonPath)
+ // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
+ workerEnv.put("PYTHONUNBUFFERED", "YES")
daemon = pb.start()
val in = new DataInputStream(daemon.getInputStream)
diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
index 7903457b17e13..a4153aaa926f8 100644
--- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
@@ -18,8 +18,13 @@
package org.apache.spark.api.python
import java.nio.ByteOrder
+import java.util.{ArrayList => JArrayList}
+
+import org.apache.spark.api.java.JavaRDD
import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+import scala.collection.mutable
import scala.util.Failure
import scala.util.Try
@@ -29,7 +34,7 @@ import org.apache.spark.{Logging, SparkException}
import org.apache.spark.rdd.RDD
/** Utilities for serialization / deserialization between Python and Java, using Pickle. */
-private[python] object SerDeUtil extends Logging {
+private[spark] object SerDeUtil extends Logging {
// Unpickle array.array generated by Python 2.6
class ArrayConstructor extends net.razorvine.pickle.objects.ArrayConstructor {
// /* Description of types */
@@ -76,8 +81,84 @@ private[python] object SerDeUtil extends Logging {
}
}
+ private var initialized = false
+ // This should be called before trying to unpickle array.array from Python
+ // In cluster mode, this should be put in closure
def initialize() = {
- Unpickler.registerConstructor("array", "array", new ArrayConstructor())
+ synchronized{
+ if (!initialized) {
+ Unpickler.registerConstructor("array", "array", new ArrayConstructor())
+ initialized = true
+ }
+ }
+ }
+ initialize()
+
+
+ /**
+ * Convert an RDD of Java objects to Array (no recursive conversions).
+ * It is only used by pyspark.sql.
+ */
+ def toJavaArray(jrdd: JavaRDD[Any]): JavaRDD[Array[_]] = {
+ jrdd.rdd.map {
+ case objs: JArrayList[_] =>
+ objs.toArray
+ case obj if obj.getClass.isArray =>
+ obj.asInstanceOf[Array[_]].toArray
+ }.toJavaRDD()
+ }
+
+ /**
+ * Choose batch size based on size of objects
+ */
+ private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
+ private val pickle = new Pickler()
+ private var batch = 1
+ private val buffer = new mutable.ArrayBuffer[Any]
+
+ override def hasNext: Boolean = iter.hasNext
+
+ override def next(): Array[Byte] = {
+ while (iter.hasNext && buffer.length < batch) {
+ buffer += iter.next()
+ }
+ val bytes = pickle.dumps(buffer.toArray)
+ val size = bytes.length
+ // let 1M < size < 10M
+ if (size < 1024 * 1024) {
+ batch *= 2
+ } else if (size > 1024 * 1024 * 10 && batch > 1) {
+ batch /= 2
+ }
+ buffer.clear()
+ bytes
+ }
+ }
+
+ /**
+ * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
+ * PySpark.
+ */
+ private[spark] def javaToPython(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = {
+ jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
+ }
+
+ /**
+ * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark.
+ */
+ def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
+ pyRDD.rdd.mapPartitions { iter =>
+ initialize()
+ val unpickle = new Unpickler
+ iter.flatMap { row =>
+ val obj = unpickle.loads(row)
+ if (batched) {
+ obj.asInstanceOf[JArrayList[_]].asScala
+ } else {
+ Seq(obj)
+ }
+ }
+ }.toJavaRDD()
}
private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = {
@@ -119,17 +200,18 @@ private[python] object SerDeUtil extends Logging {
*/
def pairRDDToPython(rdd: RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = {
val (keyFailed, valueFailed) = checkPickle(rdd.first())
+
rdd.mapPartitions { iter =>
- val pickle = new Pickler
val cleaned = iter.map { case (k, v) =>
val key = if (keyFailed) k.toString else k
val value = if (valueFailed) v.toString else v
Array[Any](key, value)
}
- if (batchSize > 1) {
- cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched)))
+ if (batchSize == 0) {
+ new AutoBatchedPickler(cleaned)
} else {
- cleaned.map(pickle.dumps(_))
+ val pickle = new Pickler
+ cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched)))
}
}
}
@@ -137,35 +219,22 @@ private[python] object SerDeUtil extends Logging {
/**
* Convert an RDD of serialized Python tuple (K, V) to RDD[(K, V)].
*/
- def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batchSerialized: Boolean): RDD[(K, V)] = {
+ def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batched: Boolean): RDD[(K, V)] = {
def isPair(obj: Any): Boolean = {
- Option(obj.getClass.getComponentType).map(!_.isPrimitive).getOrElse(false) &&
+ Option(obj.getClass.getComponentType).exists(!_.isPrimitive) &&
obj.asInstanceOf[Array[_]].length == 2
}
- pyRDD.mapPartitions { iter =>
- val unpickle = new Unpickler
- val unpickled =
- if (batchSerialized) {
- iter.flatMap { batch =>
- unpickle.loads(batch) match {
- case objs: java.util.List[_] => collectionAsScalaIterable(objs)
- case other => throw new SparkException(
- s"Unexpected type ${other.getClass.getName} for batch serialized Python RDD")
- }
- }
- } else {
- iter.map(unpickle.loads(_))
- }
- unpickled.map {
- case obj if isPair(obj) =>
- // we only accept (K, V)
- val arr = obj.asInstanceOf[Array[_]]
- (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V])
- case other => throw new SparkException(
- s"RDD element of type ${other.getClass.getName} cannot be used")
- }
+
+ val rdd = pythonToJava(pyRDD, batched).rdd
+ rdd.first match {
+ case obj if isPair(obj) =>
+ // we only accept (K, V)
+ case other => throw new SparkException(
+ s"RDD element of type ${other.getClass.getName} cannot be used")
+ }
+ rdd.map { obj =>
+ val arr = obj.asInstanceOf[Array[_]]
+ (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V])
}
}
-
}
-
diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
index d11db978b842e..c0cbd28a845be 100644
--- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
@@ -18,7 +18,8 @@
package org.apache.spark.api.python
import java.io.{DataOutput, DataInput}
-import java.nio.charset.Charset
+
+import com.google.common.base.Charsets.UTF_8
import org.apache.hadoop.io._
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat
@@ -136,7 +137,7 @@ object WriteInputFormatTestDataGenerator {
sc.parallelize(intKeys).saveAsSequenceFile(intPath)
sc.parallelize(intKeys.map{ case (k, v) => (k.toDouble, v) }).saveAsSequenceFile(doublePath)
sc.parallelize(intKeys.map{ case (k, v) => (k.toString, v) }).saveAsSequenceFile(textPath)
- sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes(Charset.forName("UTF-8"))) }
+ sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes(UTF_8)) }
).saveAsSequenceFile(bytesPath)
val bools = Seq((1, true), (2, true), (2, false), (3, true), (2, false), (1, false))
sc.parallelize(bools).saveAsSequenceFile(boolPath)
@@ -175,11 +176,11 @@ object WriteInputFormatTestDataGenerator {
// Create test data for arbitrary custom writable TestWritable
val testClass = Seq(
- ("1", TestWritable("test1", 123, 54.0)),
- ("2", TestWritable("test2", 456, 8762.3)),
- ("1", TestWritable("test3", 123, 423.1)),
- ("3", TestWritable("test56", 456, 423.5)),
- ("2", TestWritable("test2", 123, 5435.2))
+ ("1", TestWritable("test1", 1, 1.0)),
+ ("2", TestWritable("test2", 2, 2.3)),
+ ("3", TestWritable("test3", 3, 3.1)),
+ ("5", TestWritable("test56", 5, 5.5)),
+ ("4", TestWritable("test4", 4, 4.2))
)
val rdd = sc.parallelize(testClass, numSlices = 2).map{ case (k, v) => (new Text(k), v) }
rdd.saveAsNewAPIHadoopFile(classPath,
diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
index 15fd30e65761d..a5ea478f231d7 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -20,6 +20,8 @@ package org.apache.spark.broadcast
import java.io.Serializable
import org.apache.spark.SparkException
+import org.apache.spark.Logging
+import org.apache.spark.util.Utils
import scala.reflect.ClassTag
@@ -37,7 +39,7 @@ import scala.reflect.ClassTag
*
* {{{
* scala> val broadcastVar = sc.broadcast(Array(1, 2, 3))
- * broadcastVar: spark.Broadcast[Array[Int]] = spark.Broadcast(b5c40191-a864-4c7d-b9bf-d87e1a4e787c)
+ * broadcastVar: org.apache.spark.broadcast.Broadcast[Array[Int]] = Broadcast(0)
*
* scala> broadcastVar.value
* res0: Array[Int] = Array(1, 2, 3)
@@ -52,7 +54,7 @@ import scala.reflect.ClassTag
* @param id A unique identifier for the broadcast variable.
* @tparam T Type of the data contained in the broadcast variable.
*/
-abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
+abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable with Logging {
/**
* Flag signifying whether the broadcast variable is valid
@@ -60,6 +62,8 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
*/
@volatile private var _isValid = true
+ private var _destroySite = ""
+
/** Get the broadcasted value. */
def value: T = {
assertValid()
@@ -84,13 +88,26 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
doUnpersist(blocking)
}
+
+ /**
+ * Destroy all data and metadata related to this broadcast variable. Use this with caution;
+ * once a broadcast variable has been destroyed, it cannot be used again.
+ * This method blocks until destroy has completed
+ */
+ def destroy() {
+ destroy(blocking = true)
+ }
+
/**
* Destroy all data and metadata related to this broadcast variable. Use this with caution;
* once a broadcast variable has been destroyed, it cannot be used again.
+ * @param blocking Whether to block until destroy has completed
*/
private[spark] def destroy(blocking: Boolean) {
assertValid()
_isValid = false
+ _destroySite = Utils.getCallSite().shortForm
+ logInfo("Destroying %s (from %s)".format(toString, _destroySite))
doDestroy(blocking)
}
@@ -124,7 +141,8 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
/** Check if this broadcast is valid. If not valid, exception is thrown. */
protected def assertValid() {
if (!_isValid) {
- throw new SparkException("Attempted to use %s after it has been destroyed!".format(toString))
+ throw new SparkException(
+ "Attempted to use %s after it was destroyed (%s) ".format(toString, _destroySite))
}
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 942dc7d7eac87..31f0a462f84d8 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -72,13 +72,13 @@ private[spark] class HttpBroadcast[T: ClassTag](
}
/** Used by the JVM when serializing this object. */
- private def writeObject(out: ObjectOutputStream) {
+ private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
assertValid()
out.defaultWriteObject()
}
/** Used by the JVM when deserializing this object. */
- private def readObject(in: ObjectInputStream) {
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
in.defaultReadObject()
HttpBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(blockId) match {
@@ -163,18 +163,23 @@ private[broadcast] object HttpBroadcast extends Logging {
private def write(id: Long, value: Any) {
val file = getFile(id)
- val out: OutputStream = {
- if (compress) {
- compressionCodec.compressedOutputStream(new FileOutputStream(file))
- } else {
- new BufferedOutputStream(new FileOutputStream(file), bufferSize)
+ val fileOutputStream = new FileOutputStream(file)
+ try {
+ val out: OutputStream = {
+ if (compress) {
+ compressionCodec.compressedOutputStream(fileOutputStream)
+ } else {
+ new BufferedOutputStream(fileOutputStream, bufferSize)
+ }
}
+ val ser = SparkEnv.get.serializer.newInstance()
+ val serOut = ser.serializeStream(out)
+ serOut.writeObject(value)
+ serOut.close()
+ files += file
+ } finally {
+ fileOutputStream.close()
}
- val ser = SparkEnv.get.serializer.newInstance()
- val serOut = ser.serializeStream(out)
- serOut.writeObject(value)
- serOut.close()
- files += file
}
private def read[T: ClassTag](id: Long): T = {
@@ -186,10 +191,12 @@ private[broadcast] object HttpBroadcast extends Logging {
logDebug("broadcast security enabled")
val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager)
uc = newuri.toURL.openConnection()
+ uc.setConnectTimeout(httpReadTimeout)
uc.setAllowUserInteraction(false)
} else {
logDebug("broadcast not using security")
uc = new URL(url).openConnection()
+ uc.setConnectTimeout(httpReadTimeout)
}
val in = {
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 42d58682a1e23..94142d33369c7 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -26,8 +26,9 @@ import scala.util.Random
import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
+import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
-import org.apache.spark.util.ByteBufferInputStream
+import org.apache.spark.util.{ByteBufferInputStream, Utils}
import org.apache.spark.util.io.ByteArrayChunkOutputStream
/**
@@ -46,53 +47,66 @@ import org.apache.spark.util.io.ByteArrayChunkOutputStream
* This prevents the driver from being the bottleneck in sending out multiple copies of the
* broadcast data (one per executor) as done by the [[org.apache.spark.broadcast.HttpBroadcast]].
*
+ * When initialized, TorrentBroadcast objects read SparkEnv.get.conf.
+ *
* @param obj object to broadcast
- * @param isLocal whether Spark is running in local mode (single JVM process).
* @param id A unique identifier for the broadcast variable.
*/
-private[spark] class TorrentBroadcast[T: ClassTag](
- obj : T,
- @transient private val isLocal: Boolean,
- id: Long)
+private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
/**
- * Value of the broadcast object. On driver, this is set directly by the constructor.
- * On executors, this is reconstructed by [[readObject]], which builds this value by reading
- * blocks from the driver and/or other executors.
+ * Value of the broadcast object on executors. This is reconstructed by [[readBroadcastBlock]],
+ * which builds this value by reading blocks from the driver and/or other executors.
+ *
+ * On the driver, if the value is required, it is read lazily from the block manager.
*/
- @transient private var _value: T = obj
+ @transient private lazy val _value: T = readBroadcastBlock()
+
+ /** The compression codec to use, or None if compression is disabled */
+ @transient private var compressionCodec: Option[CompressionCodec] = _
+ /** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */
+ @transient private var blockSize: Int = _
+
+ private def setConf(conf: SparkConf) {
+ compressionCodec = if (conf.getBoolean("spark.broadcast.compress", true)) {
+ Some(CompressionCodec.createCodec(conf))
+ } else {
+ None
+ }
+ blockSize = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
+ }
+ setConf(SparkEnv.get.conf)
private val broadcastId = BroadcastBlockId(id)
/** Total number of blocks this broadcast variable contains. */
- private val numBlocks: Int = writeBlocks()
+ private val numBlocks: Int = writeBlocks(obj)
- override protected def getValue() = _value
+ override protected def getValue() = {
+ _value
+ }
/**
* Divide the object into multiple blocks and put those blocks in the block manager.
- *
+ * @param value the object to divide
* @return number of blocks this broadcast variable is divided into
*/
- private def writeBlocks(): Int = {
- // For local mode, just put the object in the BlockManager so we can find it later.
- SparkEnv.get.blockManager.putSingle(
- broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
-
- if (!isLocal) {
- val blocks = TorrentBroadcast.blockifyObject(_value)
- blocks.zipWithIndex.foreach { case (block, i) =>
- SparkEnv.get.blockManager.putBytes(
- BroadcastBlockId(id, "piece" + i),
- block,
- StorageLevel.MEMORY_AND_DISK_SER,
- tellMaster = true)
- }
- blocks.length
- } else {
- 0
+ private def writeBlocks(value: T): Int = {
+ // Store a copy of the broadcast variable in the driver so that tasks run on the driver
+ // do not create a duplicate copy of the broadcast variable's value.
+ SparkEnv.get.blockManager.putSingle(broadcastId, value, StorageLevel.MEMORY_AND_DISK,
+ tellMaster = false)
+ val blocks =
+ TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
+ blocks.zipWithIndex.foreach { case (block, i) =>
+ SparkEnv.get.blockManager.putBytes(
+ BroadcastBlockId(id, "piece" + i),
+ block,
+ StorageLevel.MEMORY_AND_DISK_SER,
+ tellMaster = true)
}
+ blocks.length
}
/** Fetch torrent blocks from the driver and/or other executors. */
@@ -104,29 +118,24 @@ private[spark] class TorrentBroadcast[T: ClassTag](
for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
val pieceId = BroadcastBlockId(id, "piece" + pid)
-
- // First try getLocalBytes because there is a chance that previous attempts to fetch the
+ logDebug(s"Reading piece $pieceId of $broadcastId")
+ // First try getLocalBytes because there is a chance that previous attempts to fetch the
// broadcast blocks have already fetched some of the blocks. In that case, some blocks
// would be available locally (on this executor).
- var blockOpt = bm.getLocalBytes(pieceId)
- if (!blockOpt.isDefined) {
- blockOpt = bm.getRemoteBytes(pieceId)
- blockOpt match {
- case Some(block) =>
- // If we found the block from remote executors/driver's BlockManager, put the block
- // in this executor's BlockManager.
- SparkEnv.get.blockManager.putBytes(
- pieceId,
- block,
- StorageLevel.MEMORY_AND_DISK_SER,
- tellMaster = true)
-
- case None =>
- throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
- }
+ def getLocal: Option[ByteBuffer] = bm.getLocalBytes(pieceId)
+ def getRemote: Option[ByteBuffer] = bm.getRemoteBytes(pieceId).map { block =>
+ // If we found the block from remote executors/driver's BlockManager, put the block
+ // in this executor's BlockManager.
+ SparkEnv.get.blockManager.putBytes(
+ pieceId,
+ block,
+ StorageLevel.MEMORY_AND_DISK_SER,
+ tellMaster = true)
+ block
}
- // If we get here, the option is defined.
- blocks(pid) = blockOpt.get
+ val block: ByteBuffer = getLocal.orElse(getRemote).getOrElse(
+ throw new SparkException(s"Failed to get $pieceId of $broadcastId"))
+ blocks(pid) = block
}
blocks
}
@@ -147,75 +156,62 @@ private[spark] class TorrentBroadcast[T: ClassTag](
}
/** Used by the JVM when serializing this object. */
- private def writeObject(out: ObjectOutputStream) {
+ private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
assertValid()
out.defaultWriteObject()
}
- /** Used by the JVM when deserializing this object. */
- private def readObject(in: ObjectInputStream) {
- in.defaultReadObject()
+ private def readBroadcastBlock(): T = Utils.tryOrIOException {
TorrentBroadcast.synchronized {
+ setConf(SparkEnv.get.conf)
SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match {
case Some(x) =>
- _value = x.asInstanceOf[T]
+ x.asInstanceOf[T]
case None =>
logInfo("Started reading broadcast variable " + id)
- val start = System.nanoTime()
+ val startTimeMs = System.currentTimeMillis()
val blocks = readBlocks()
- val time = (System.nanoTime() - start) / 1e9
- logInfo("Reading broadcast variable " + id + " took " + time + " s")
+ logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))
- _value = TorrentBroadcast.unBlockifyObject[T](blocks)
+ val obj = TorrentBroadcast.unBlockifyObject[T](
+ blocks, SparkEnv.get.serializer, compressionCodec)
// Store the merged copy in BlockManager so other tasks on this executor don't
// need to re-fetch it.
SparkEnv.get.blockManager.putSingle(
- broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+ broadcastId, obj, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+ obj
}
}
}
+
}
private object TorrentBroadcast extends Logging {
- /** Size of each block. Default value is 4MB. */
- private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
- private var initialized = false
- private var conf: SparkConf = null
- private var compress: Boolean = false
- private var compressionCodec: CompressionCodec = null
-
- def initialize(_isDriver: Boolean, conf: SparkConf) {
- TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests
- synchronized {
- if (!initialized) {
- compress = conf.getBoolean("spark.broadcast.compress", true)
- compressionCodec = CompressionCodec.createCodec(conf)
- initialized = true
- }
- }
- }
- def stop() {
- initialized = false
- }
-
- def blockifyObject[T: ClassTag](obj: T): Array[ByteBuffer] = {
- val bos = new ByteArrayChunkOutputStream(BLOCK_SIZE)
- val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos
- val ser = SparkEnv.get.serializer.newInstance()
+ def blockifyObject[T: ClassTag](
+ obj: T,
+ blockSize: Int,
+ serializer: Serializer,
+ compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = {
+ val bos = new ByteArrayChunkOutputStream(blockSize)
+ val out: OutputStream = compressionCodec.map(c => c.compressedOutputStream(bos)).getOrElse(bos)
+ val ser = serializer.newInstance()
val serOut = ser.serializeStream(out)
serOut.writeObject[T](obj).close()
bos.toArrays.map(ByteBuffer.wrap)
}
- def unBlockifyObject[T: ClassTag](blocks: Array[ByteBuffer]): T = {
+ def unBlockifyObject[T: ClassTag](
+ blocks: Array[ByteBuffer],
+ serializer: Serializer,
+ compressionCodec: Option[CompressionCodec]): T = {
+ require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks")
val is = new SequenceInputStream(
asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block))))
- val in: InputStream = if (compress) compressionCodec.compressedInputStream(is) else is
-
- val ser = SparkEnv.get.serializer.newInstance()
+ val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is)
+ val ser = serializer.newInstance()
val serIn = ser.deserializeStream(in)
val obj = serIn.readObject[T]()
serIn.close()
@@ -227,6 +223,7 @@ private object TorrentBroadcast extends Logging {
* If removeFromDriver is true, also remove these persisted blocks on the driver.
*/
def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = {
+ logDebug(s"Unpersisting TorrentBroadcast $id")
SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
}
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
index ad0f701d7a98f..fb024c12094f2 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
@@ -28,14 +28,13 @@ import org.apache.spark.{SecurityManager, SparkConf}
*/
class TorrentBroadcastFactory extends BroadcastFactory {
- override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
- TorrentBroadcast.initialize(isDriver, conf)
- }
+ override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { }
- override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) =
- new TorrentBroadcast[T](value_, isLocal, id)
+ override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = {
+ new TorrentBroadcast[T](value_, id)
+ }
- override def stop() { TorrentBroadcast.stop() }
+ override def stop() { }
/**
* Remove all persisted state associated with the torrent broadcast with the given ID.
diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala
index 065ddda50e65e..f2687ce6b42b4 100644
--- a/core/src/main/scala/org/apache/spark/deploy/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -130,7 +130,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf)
println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.")
System.exit(-1)
- case AssociationErrorEvent(cause, _, remoteAddress, _) =>
+ case AssociationErrorEvent(cause, _, remoteAddress, _, _) =>
println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.")
println(s"Cause was: $cause")
System.exit(-1)
diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
index 39150deab863c..2e1e52906ceeb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
@@ -17,6 +17,8 @@
package org.apache.spark.deploy
+import java.net.{URI, URISyntaxException}
+
import scala.collection.mutable.ListBuffer
import org.apache.log4j.Level
@@ -73,7 +75,8 @@ private[spark] class ClientArguments(args: Array[String]) {
if (!ClientArguments.isValidJarUrl(_jarUrl)) {
println(s"Jar url '${_jarUrl}' is not in valid format.")
- println(s"Must be a jar file path in URL format (e.g. hdfs://XX.jar, file://XX.jar)")
+ println(s"Must be a jar file path in URL format " +
+ "(e.g. hdfs://host:port/XX.jar, file:///XX.jar)")
printUsageAndExit(-1)
}
@@ -114,5 +117,12 @@ private[spark] class ClientArguments(args: Array[String]) {
}
object ClientArguments {
- def isValidJarUrl(s: String): Boolean = s.matches("(.+):(.+)jar")
+ def isValidJarUrl(s: String): Boolean = {
+ try {
+ val uri = new URI(s)
+ uri.getScheme != null && uri.getPath != null && uri.getPath.endsWith(".jar")
+ } catch {
+ case _: URISyntaxException => false
+ }
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
index a7368f9f3dfbe..c46f84de8444a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
@@ -71,6 +71,8 @@ private[deploy] object DeployMessages {
case class RegisterWorkerFailed(message: String) extends DeployMessage
+ case class ReconnectWorker(masterUrl: String) extends DeployMessage
+
case class KillExecutor(masterUrl: String, appId: String, execId: Int) extends DeployMessage
case class LaunchExecutor(
@@ -90,6 +92,8 @@ private[deploy] object DeployMessages {
case object WorkDirCleanup // Sent to Worker actor periodically for cleaning up app folders
+ case object ReregisterWithMaster // used when a worker attempts to reconnect to a master
+
// AppClient to Master
case class RegisterApplication(appDescription: ApplicationDescription)
diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
index 79b4d7ea41a33..039c8719e2867 100644
--- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
@@ -34,7 +34,8 @@ object PythonRunner {
val pythonFile = args(0)
val pyFiles = args(1)
val otherArgs = args.slice(2, args.length)
- val pythonExec = sys.env.get("PYSPARK_PYTHON").getOrElse("python") // TODO: get this from conf
+ val pythonExec =
+ sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", sys.env.getOrElse("PYSPARK_PYTHON", "python"))
// Format python file paths before adding them to the PYTHONPATH
val formattedPythonFile = formatPath(pythonFile)
@@ -57,6 +58,7 @@ object PythonRunner {
val builder = new ProcessBuilder(Seq(pythonExec, formattedPythonFile) ++ otherArgs)
val env = builder.environment()
env.put("PYTHONPATH", pythonPath)
+ // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
@@ -85,8 +87,8 @@ object PythonRunner {
// Strip the URI scheme from the path
formattedPath =
new URI(formattedPath).getScheme match {
- case Utils.windowsDrive(d) if windows => formattedPath
case null => formattedPath
+ case Utils.windowsDrive(d) if windows => formattedPath
case _ => new URI(formattedPath).getPath
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index fe0ad9ebbca12..60ee115e393ce 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -17,15 +17,19 @@
package org.apache.spark.deploy
+import java.lang.reflect.Method
import java.security.PrivilegedExceptionAction
import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.FileSystem.Statistics
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.security.Credentials
import org.apache.hadoop.security.UserGroupInformation
import org.apache.spark.{Logging, SparkContext, SparkConf, SparkException}
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.util.Utils
import scala.collection.JavaConversions._
@@ -121,6 +125,64 @@ class SparkHadoopUtil extends Logging {
UserGroupInformation.loginUserFromKeytab(principalName, keytabFilename)
}
+ /**
+ * Returns a function that can be called to find Hadoop FileSystem bytes read. If
+ * getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will
+ * return the bytes read on r since t. Reflection is required because thread-level FileSystem
+ * statistics are only available as of Hadoop 2.5 (see HADOOP-10688).
+ * Returns None if the required method can't be found.
+ */
+ private[spark] def getFSBytesReadOnThreadCallback(path: Path, conf: Configuration)
+ : Option[() => Long] = {
+ try {
+ val threadStats = getFileSystemThreadStatistics(path, conf)
+ val getBytesReadMethod = getFileSystemThreadStatisticsMethod("getBytesRead")
+ val f = () => threadStats.map(getBytesReadMethod.invoke(_).asInstanceOf[Long]).sum
+ val baselineBytesRead = f()
+ Some(() => f() - baselineBytesRead)
+ } catch {
+ case e: NoSuchMethodException => {
+ logDebug("Couldn't find method for retrieving thread-level FileSystem input data", e)
+ None
+ }
+ }
+ }
+
+ /**
+ * Returns a function that can be called to find Hadoop FileSystem bytes written. If
+ * getFSBytesWrittenOnThreadCallback is called from thread r at time t, the returned callback will
+ * return the bytes written on r since t. Reflection is required because thread-level FileSystem
+ * statistics are only available as of Hadoop 2.5 (see HADOOP-10688).
+ * Returns None if the required method can't be found.
+ */
+ private[spark] def getFSBytesWrittenOnThreadCallback(path: Path, conf: Configuration)
+ : Option[() => Long] = {
+ try {
+ val threadStats = getFileSystemThreadStatistics(path, conf)
+ val getBytesWrittenMethod = getFileSystemThreadStatisticsMethod("getBytesWritten")
+ val f = () => threadStats.map(getBytesWrittenMethod.invoke(_).asInstanceOf[Long]).sum
+ val baselineBytesWritten = f()
+ Some(() => f() - baselineBytesWritten)
+ } catch {
+ case e: NoSuchMethodException => {
+ logDebug("Couldn't find method for retrieving thread-level FileSystem output data", e)
+ None
+ }
+ }
+ }
+
+ private def getFileSystemThreadStatistics(path: Path, conf: Configuration): Seq[AnyRef] = {
+ val qualifiedPath = path.getFileSystem(conf).makeQualified(path)
+ val scheme = qualifiedPath.toUri().getScheme()
+ val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme))
+ stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics"))
+ }
+
+ private def getFileSystemThreadStatisticsMethod(methodName: String): Method = {
+ val statisticsDataClass =
+ Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData")
+ statisticsDataClass.getDeclaredMethod(methodName)
+ }
}
object SparkHadoopUtil {
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index f97bf67fa5a3b..00f291823e984 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -158,8 +158,9 @@ object SparkSubmit {
args.files = mergeFileLists(args.files, args.primaryResource)
}
args.files = mergeFileLists(args.files, args.pyFiles)
- // Format python file paths properly before adding them to the PYTHONPATH
- sysProps("spark.submit.pyFiles") = PythonRunner.formatPaths(args.pyFiles).mkString(",")
+ if (args.pyFiles != null) {
+ sysProps("spark.submit.pyFiles") = args.pyFiles
+ }
}
// Special flag to avoid deprecation warnings at the client
@@ -273,15 +274,32 @@ object SparkSubmit {
}
}
- // Properties given with --conf are superceded by other options, but take precedence over
- // properties in the defaults file.
+ // Load any properties specified through --conf and the default properties file
for ((k, v) <- args.sparkProperties) {
sysProps.getOrElseUpdate(k, v)
}
- // Read from default spark properties, if any
- for ((k, v) <- args.defaultSparkProperties) {
- sysProps.getOrElseUpdate(k, v)
+ // Resolve paths in certain spark properties
+ val pathConfigs = Seq(
+ "spark.jars",
+ "spark.files",
+ "spark.yarn.jar",
+ "spark.yarn.dist.files",
+ "spark.yarn.dist.archives")
+ pathConfigs.foreach { config =>
+ // Replace old URIs with resolved URIs, if they exist
+ sysProps.get(config).foreach { oldValue =>
+ sysProps(config) = Utils.resolveURIs(oldValue)
+ }
+ }
+
+ // Resolve and format python file paths properly before adding them to the PYTHONPATH.
+ // The resolving part is redundant in the case of --py-files, but necessary if the user
+ // explicitly sets `spark.submit.pyFiles` in his/her default properties file.
+ sysProps.get("spark.submit.pyFiles").foreach { pyFiles =>
+ val resolvedPyFiles = Utils.resolveURIs(pyFiles)
+ val formattedPyFiles = PythonRunner.formatPaths(resolvedPyFiles).mkString(",")
+ sysProps("spark.submit.pyFiles") = formattedPyFiles
}
(childArgs, childClasspath, sysProps, childMainClass)
@@ -322,11 +340,16 @@ object SparkSubmit {
e.printStackTrace(printStream)
if (childMainClass.contains("thriftserver")) {
println(s"Failed to load main class $childMainClass.")
- println("You need to build Spark with -Phive.")
+ println("You need to build Spark with -Phive and -Phive-thriftserver.")
}
System.exit(CLASS_NOT_FOUND_EXIT_STATUS)
}
+ // SPARK-4170
+ if (classOf[scala.App].isAssignableFrom(mainClass)) {
+ printWarning("Subclasses of scala.App may not work correctly. Use a main() method instead.")
+ }
+
val mainMethod = mainClass.getMethod("main", new Array[String](0).getClass)
if (!Modifier.isStatic(mainMethod.getModifiers)) {
throw new IllegalStateException("The main method in the given main class must be static")
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index 57b251ff47714..f0e9ee67f6a67 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -17,14 +17,10 @@
package org.apache.spark.deploy
-import java.io.{File, FileInputStream, IOException}
-import java.util.Properties
import java.util.jar.JarFile
-import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
-import org.apache.spark.SparkException
import org.apache.spark.util.Utils
/**
@@ -63,9 +59,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
val defaultProperties = new HashMap[String, String]()
if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile")
Option(propertiesFile).foreach { filename =>
- val file = new File(filename)
- SparkSubmitArguments.getPropertiesFromFile(file).foreach { case (k, v) =>
- if (k.startsWith("spark")) {
+ Utils.getPropertiesFromFile(filename).foreach { case (k, v) =>
+ if (k.startsWith("spark.")) {
defaultProperties(k) = v
if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v")
} else {
@@ -76,51 +71,54 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
defaultProperties
}
- // Respect SPARK_*_MEMORY for cluster mode
- driverMemory = sys.env.get("SPARK_DRIVER_MEMORY").orNull
- executorMemory = sys.env.get("SPARK_EXECUTOR_MEMORY").orNull
-
+ // Set parameters from command line arguments
parseOpts(args.toList)
- mergeSparkProperties()
+ // Populate `sparkProperties` map from properties file
+ mergeDefaultSparkProperties()
+ // Use `sparkProperties` map along with env vars to fill in any missing parameters
+ loadEnvironmentArguments()
+
checkRequiredArguments()
/**
- * Fill in any undefined values based on the default properties file or options passed in through
- * the '--conf' flag.
+ * Merge values from the default properties file with those specified through --conf.
+ * When this is called, `sparkProperties` is already filled with configs from the latter.
*/
- private def mergeSparkProperties(): Unit = {
+ private def mergeDefaultSparkProperties(): Unit = {
// Use common defaults file, if not specified by user
- if (propertiesFile == null) {
- val sep = File.separator
- val sparkHomeConfig = env.get("SPARK_HOME").map(sparkHome => s"${sparkHome}${sep}conf")
- val confDir = env.get("SPARK_CONF_DIR").orElse(sparkHomeConfig)
-
- confDir.foreach { sparkConfDir =>
- val defaultPath = s"${sparkConfDir}${sep}spark-defaults.conf"
- val file = new File(defaultPath)
- if (file.exists()) {
- propertiesFile = file.getAbsolutePath
- }
+ propertiesFile = Option(propertiesFile).getOrElse(Utils.getDefaultPropertiesFile(env))
+ // Honor --conf before the defaults file
+ defaultSparkProperties.foreach { case (k, v) =>
+ if (!sparkProperties.contains(k)) {
+ sparkProperties(k) = v
}
}
+ }
- val properties = HashMap[String, String]()
- properties.putAll(defaultSparkProperties)
- properties.putAll(sparkProperties)
-
- // Use properties file as fallback for values which have a direct analog to
- // arguments in this script.
- master = Option(master).orElse(properties.get("spark.master")).orNull
- executorMemory = Option(executorMemory).orElse(properties.get("spark.executor.memory")).orNull
- executorCores = Option(executorCores).orElse(properties.get("spark.executor.cores")).orNull
+ /**
+ * Load arguments from environment variables, Spark properties etc.
+ */
+ private def loadEnvironmentArguments(): Unit = {
+ master = Option(master)
+ .orElse(sparkProperties.get("spark.master"))
+ .orElse(env.get("MASTER"))
+ .orNull
+ driverMemory = Option(driverMemory)
+ .orElse(sparkProperties.get("spark.driver.memory"))
+ .orElse(env.get("SPARK_DRIVER_MEMORY"))
+ .orNull
+ executorMemory = Option(executorMemory)
+ .orElse(sparkProperties.get("spark.executor.memory"))
+ .orElse(env.get("SPARK_EXECUTOR_MEMORY"))
+ .orNull
+ executorCores = Option(executorCores)
+ .orElse(sparkProperties.get("spark.executor.cores"))
+ .orNull
totalExecutorCores = Option(totalExecutorCores)
- .orElse(properties.get("spark.cores.max"))
+ .orElse(sparkProperties.get("spark.cores.max"))
.orNull
- name = Option(name).orElse(properties.get("spark.app.name")).orNull
- jars = Option(jars).orElse(properties.get("spark.jars")).orNull
-
- // This supports env vars in older versions of Spark
- master = Option(master).orElse(env.get("MASTER")).orNull
+ name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull
+ jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull
deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull
// Try to set main class from JAR if no --class argument is given
@@ -147,7 +145,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
}
/** Ensure that required fields exists. Call this only once all defaults are loaded. */
- private def checkRequiredArguments() = {
+ private def checkRequiredArguments(): Unit = {
if (args.length == 0) {
printUsageAndExit(-1)
}
@@ -182,7 +180,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
}
}
- override def toString = {
+ override def toString = {
s"""Parsed arguments:
| master $master
| deployMode $deployMode
@@ -190,7 +188,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
| executorCores $executorCores
| totalExecutorCores $totalExecutorCores
| propertiesFile $propertiesFile
- | extraSparkProperties $sparkProperties
| driverMemory $driverMemory
| driverCores $driverCores
| driverExtraClassPath $driverExtraClassPath
@@ -209,8 +206,9 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
| jars $jars
| verbose $verbose
|
- |Default properties from $propertiesFile:
- |${defaultSparkProperties.mkString(" ", "\n ", "\n")}
+ |Spark properties used, including those specified through
+ | --conf and those from the properties file $propertiesFile:
+ |${sparkProperties.mkString(" ", "\n ", "\n")}
""".stripMargin
}
@@ -343,7 +341,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
}
}
- private def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
+ private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = {
val outStream = SparkSubmit.printStream
if (unknownParam != null) {
outStream.println("Unknown/unsupported param " + unknownParam)
@@ -397,23 +395,3 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
SparkSubmit.exitFn()
}
}
-
-object SparkSubmitArguments {
- /** Load properties present in the given file. */
- def getPropertiesFromFile(file: File): Seq[(String, String)] = {
- require(file.exists(), s"Properties file $file does not exist")
- require(file.isFile(), s"Properties file $file is not a normal file")
- val inputStream = new FileInputStream(file)
- try {
- val properties = new Properties()
- properties.load(inputStream)
- properties.stringPropertyNames().toSeq.map(k => (k, properties(k).trim))
- } catch {
- case e: IOException =>
- val message = s"Failed when loading Spark properties file $file"
- throw new SparkException(message, e)
- } finally {
- inputStream.close()
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
index a64170a47bc1c..d2687faad62b1 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
@@ -68,7 +68,7 @@ private[spark] object SparkSubmitDriverBootstrapper {
assume(bootstrapDriver != null, "SPARK_SUBMIT_BOOTSTRAP_DRIVER must be set")
// Parse the properties file for the equivalent spark.driver.* configs
- val properties = SparkSubmitArguments.getPropertiesFromFile(new File(propertiesFile)).toMap
+ val properties = Utils.getPropertiesFromFile(propertiesFile)
val confDriverMemory = properties.get("spark.driver.memory")
val confLibraryPath = properties.get("spark.driver.extraLibraryPath")
val confClasspath = properties.get("spark.driver.extraClassPath")
@@ -82,17 +82,8 @@ private[spark] object SparkSubmitDriverBootstrapper {
.orElse(confDriverMemory)
.getOrElse(defaultDriverMemory)
- val newLibraryPath =
- if (submitLibraryPath.isDefined) {
- // SPARK_SUBMIT_LIBRARY_PATH is already captured in JAVA_OPTS
- ""
- } else {
- confLibraryPath.map("-Djava.library.path=" + _).getOrElse("")
- }
-
val newClasspath =
if (submitClasspath.isDefined) {
- // SPARK_SUBMIT_CLASSPATH is already captured in CLASSPATH
classpath
} else {
classpath + confClasspath.map(sys.props("path.separator") + _).getOrElse("")
@@ -114,7 +105,6 @@ private[spark] object SparkSubmitDriverBootstrapper {
val command: Seq[String] =
Seq(runner) ++
Seq("-cp", newClasspath) ++
- Seq(newLibraryPath) ++
filteredJavaOpts ++
Seq(s"-Xms$newDriverMemory", s"-Xmx$newDriverMemory") ++
Seq("org.apache.spark.deploy.SparkSubmit") ++
@@ -130,8 +120,25 @@ private[spark] object SparkSubmitDriverBootstrapper {
// Start the driver JVM
val filteredCommand = command.filter(_.nonEmpty)
val builder = new ProcessBuilder(filteredCommand)
+ val env = builder.environment()
+
+ if (submitLibraryPath.isEmpty && confLibraryPath.nonEmpty) {
+ val libraryPaths = confLibraryPath ++ sys.env.get(Utils.libraryPathEnvName)
+ env.put(Utils.libraryPathEnvName, libraryPaths.mkString(sys.props("path.separator")))
+ }
+
val process = builder.start()
+ // If we kill an app while it's running, its sub-process should be killed too.
+ Runtime.getRuntime().addShutdownHook(new Thread() {
+ override def run() = {
+ if (process != null) {
+ process.destroy()
+ process.waitFor()
+ }
+ }
+ })
+
// Redirect stdout and stderr from the child JVM
val stdoutThread = new RedirectThread(process.getInputStream, System.out, "redirect stdout")
val stderrThread = new RedirectThread(process.getErrorStream, System.err, "redirect stderr")
@@ -142,14 +149,15 @@ private[spark] object SparkSubmitDriverBootstrapper {
// subprocess there already reads directly from our stdin, so we should avoid spawning a
// thread that contends with the subprocess in reading from System.in.
val isWindows = Utils.isWindows
- val isPySparkShell = sys.env.contains("PYSPARK_SHELL")
+ val isSubprocess = sys.env.contains("IS_SUBPROCESS")
if (!isWindows) {
val stdinThread = new RedirectThread(System.in, process.getOutputStream, "redirect stdin")
stdinThread.start()
- // For the PySpark shell, Spark submit itself runs as a python subprocess, and so this JVM
- // should terminate on broken pipe, which signals that the parent process has exited. In
- // Windows, the termination logic for the PySpark shell is handled in java_gateway.py
- if (isPySparkShell) {
+ // Spark submit (JVM) may run as a subprocess, and so this JVM should terminate on
+ // broken pipe, signaling that the parent process has exited. This is the case if the
+ // application is launched directly from python, as in the PySpark shell. In Windows,
+ // the termination logic is handled in java_gateway.py
+ if (isSubprocess) {
stdinThread.join()
process.destroy()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
index 32790053a6be8..98a93d1fcb2a3 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
@@ -154,7 +154,7 @@ private[spark] class AppClient(
logWarning(s"Connection to $address failed; waiting for master to reconnect...")
markDisconnected()
- case AssociationErrorEvent(cause, _, address, _) if isPossibleMaster(address) =>
+ case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) =>
logWarning(s"Could not connect to $address: $cause")
case StopAppClient =>
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
index 481f6c93c6a8d..82a54dbfb5330 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
@@ -29,22 +29,27 @@ import org.apache.spark.scheduler._
import org.apache.spark.ui.SparkUI
import org.apache.spark.util.Utils
+/**
+ * A class that provides application history from event logs stored in the file system.
+ * This provider checks for new finished applications in the background periodically and
+ * renders the history application UI by parsing the associated event logs.
+ */
private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHistoryProvider
with Logging {
+ import FsHistoryProvider._
+
private val NOT_STARTED = ""
// Interval between each check for event log updates
private val UPDATE_INTERVAL_MS = conf.getInt("spark.history.fs.updateInterval",
conf.getInt("spark.history.updateInterval", 10)) * 1000
- private val logDir = conf.get("spark.history.fs.logDirectory", null)
- private val resolvedLogDir = Option(logDir)
- .map { d => Utils.resolveURI(d) }
- .getOrElse { throw new IllegalArgumentException("Logging directory must be specified.") }
+ private val logDir = conf.getOption("spark.history.fs.logDirectory")
+ .map { d => Utils.resolveURI(d).toString }
+ .getOrElse(DEFAULT_LOG_DIR)
- private val fs = Utils.getHadoopFileSystem(resolvedLogDir,
- SparkHadoopUtil.get.newConfiguration(conf))
+ private val fs = Utils.getHadoopFileSystem(logDir, SparkHadoopUtil.get.newConfiguration(conf))
// A timestamp of when the disk was last accessed to check for log updates
private var lastLogCheckTimeMs = -1L
@@ -87,14 +92,17 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
private def initialize() {
// Validate the log directory.
- val path = new Path(resolvedLogDir)
+ val path = new Path(logDir)
if (!fs.exists(path)) {
- throw new IllegalArgumentException(
- "Logging directory specified does not exist: %s".format(resolvedLogDir))
+ var msg = s"Log directory specified does not exist: $logDir."
+ if (logDir == DEFAULT_LOG_DIR) {
+ msg += " Did you configure the correct one through spark.fs.history.logDirectory?"
+ }
+ throw new IllegalArgumentException(msg)
}
if (!fs.getFileStatus(path).isDir) {
throw new IllegalArgumentException(
- "Logging directory specified is not a directory: %s".format(resolvedLogDir))
+ "Logging directory specified is not a directory: %s".format(logDir))
}
checkForLogs()
@@ -112,7 +120,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
val ui = {
val conf = this.conf.clone()
val appSecManager = new SecurityManager(conf)
- new SparkUI(conf, appSecManager, replayBus, appId,
+ SparkUI.createHistoryUI(conf, replayBus, appSecManager, appId,
s"${HistoryServer.UI_PATH_PREFIX}/$appId")
// Do not call ui.bind() to avoid creating a new server for each application
}
@@ -134,8 +142,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
}
}
- override def getConfig(): Map[String, String] =
- Map("Event Log Location" -> resolvedLogDir.toString)
+ override def getConfig(): Map[String, String] = Map("Event log directory" -> logDir.toString)
/**
* Builds the application list based on the current contents of the log directory.
@@ -146,7 +153,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
lastLogCheckTimeMs = getMonotonicTimeMs()
logDebug("Checking for logs. Time is now %d.".format(lastLogCheckTimeMs))
try {
- val logStatus = fs.listStatus(new Path(resolvedLogDir))
+ val logStatus = fs.listStatus(new Path(logDir))
val logDirs = if (logStatus != null) logStatus.filter(_.isDir).toSeq else Seq[FileStatus]()
// Load all new logs from the log directory. Only directories that have a modification time
@@ -244,6 +251,10 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
}
+private object FsHistoryProvider {
+ val DEFAULT_LOG_DIR = "file:/tmp/spark-events"
+}
+
private class FsApplicationHistoryInfo(
val logDir: String,
id: String,
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
index d25c29113d6da..5fdc350cd8512 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
@@ -58,7 +58,13 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
++
appTable
} else {
-
No Completed Applications Found
+
No completed applications found!
++
+
Did you specify the correct logging directory?
+ Please verify your setting of
+ spark.history.fs.logDirectory and whether you have the permissions to
+ access it. It is also possible that your application did not run to
+ completion or did not stop the SparkContext.
+
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
index 25fc76c23e0fb..b1270ade9f750 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
@@ -17,26 +17,33 @@
package org.apache.spark.deploy.history
-import org.apache.spark.SparkConf
+import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.util.Utils
/**
* Command-line parser for the master.
*/
-private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]) {
- private var logDir: String = null
+private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]) extends Logging {
+ private var propertiesFile: String = null
parse(args.toList)
private def parse(args: List[String]): Unit = {
args match {
case ("--dir" | "-d") :: value :: tail =>
- logDir = value
+ logWarning("Setting log directory through the command line is deprecated as of " +
+ "Spark 1.1.0. Please set this through spark.history.fs.logDirectory instead.")
conf.set("spark.history.fs.logDirectory", value)
+ System.setProperty("spark.history.fs.logDirectory", value)
parse(tail)
case ("--help" | "-h") :: tail =>
printUsageAndExit(0)
+ case ("--properties-file") :: value :: tail =>
+ propertiesFile = value
+ parse(tail)
+
case Nil =>
case _ =>
@@ -44,10 +51,17 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]
}
}
+ // This mutates the SparkConf, so all accesses to it must be made after this line
+ Utils.loadDefaultSparkProperties(conf, propertiesFile)
+
private def printUsageAndExit(exitCode: Int) {
System.err.println(
"""
- |Usage: HistoryServer
+ |Usage: HistoryServer [options]
+ |
+ |Options:
+ | --properties-file FILE Path to a custom Spark properties file.
+ | Default is conf/spark-defaults.conf.
|
|Configuration options can be set by setting the corresponding JVM system property.
|History Server options are always available; additional options depend on the provider.
@@ -64,9 +78,10 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]
| (default 50)
|FsHistoryProvider options:
|
- | spark.history.fs.logDirectory Directory where app logs are stored (required)
- | spark.history.fs.updateInterval How often to reload log data from storage (in seconds,
- | default 10)
+ | spark.history.fs.logDirectory Directory where app logs are stored
+ | (default: file:/tmp/spark-events)
+ | spark.history.fs.updateInterval How often to reload log data from storage
+ | (in seconds, default: 10)
|""".stripMargin)
System.exit(exitCode)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
index c3ca43f8d0734..ad7d81747c377 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
@@ -24,7 +24,9 @@ import scala.collection.mutable.ArrayBuffer
import akka.actor.ActorRef
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.deploy.ApplicationDescription
+import org.apache.spark.util.Utils
private[spark] class ApplicationInfo(
val startTime: Long,
@@ -46,7 +48,7 @@ private[spark] class ApplicationInfo(
init()
- private def readObject(in: java.io.ObjectInputStream): Unit = {
+ private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
in.defaultReadObject()
init()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala
index 80b570a44af18..9d3d7938c6ccb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala
@@ -19,7 +19,9 @@ package org.apache.spark.deploy.master
import java.util.Date
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.deploy.DriverDescription
+import org.apache.spark.util.Utils
private[spark] class DriverInfo(
val startTime: Long,
@@ -36,7 +38,7 @@ private[spark] class DriverInfo(
init()
- private def readObject(in: java.io.ObjectInputStream): Unit = {
+ private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
in.defaultReadObject()
init()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
index aa85aa060d9c1..36a2e2c6a6349 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
@@ -19,10 +19,13 @@ package org.apache.spark.deploy.master
import java.io._
+import scala.reflect.ClassTag
+
import akka.serialization.Serialization
import org.apache.spark.Logging
+
/**
* Stores data in a single on-disk directory with one file per application and worker.
* Files are deleted when applications and workers are removed.
@@ -37,64 +40,43 @@ private[spark] class FileSystemPersistenceEngine(
new File(dir).mkdir()
- override def addApplication(app: ApplicationInfo) {
- val appFile = new File(dir + File.separator + "app_" + app.id)
- serializeIntoFile(appFile, app)
- }
-
- override def removeApplication(app: ApplicationInfo) {
- new File(dir + File.separator + "app_" + app.id).delete()
- }
-
- override def addDriver(driver: DriverInfo) {
- val driverFile = new File(dir + File.separator + "driver_" + driver.id)
- serializeIntoFile(driverFile, driver)
- }
-
- override def removeDriver(driver: DriverInfo) {
- new File(dir + File.separator + "driver_" + driver.id).delete()
- }
-
- override def addWorker(worker: WorkerInfo) {
- val workerFile = new File(dir + File.separator + "worker_" + worker.id)
- serializeIntoFile(workerFile, worker)
+ override def persist(name: String, obj: Object): Unit = {
+ serializeIntoFile(new File(dir + File.separator + name), obj)
}
- override def removeWorker(worker: WorkerInfo) {
- new File(dir + File.separator + "worker_" + worker.id).delete()
+ override def unpersist(name: String): Unit = {
+ new File(dir + File.separator + name).delete()
}
- override def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
- val sortedFiles = new File(dir).listFiles().sortBy(_.getName)
- val appFiles = sortedFiles.filter(_.getName.startsWith("app_"))
- val apps = appFiles.map(deserializeFromFile[ApplicationInfo])
- val driverFiles = sortedFiles.filter(_.getName.startsWith("driver_"))
- val drivers = driverFiles.map(deserializeFromFile[DriverInfo])
- val workerFiles = sortedFiles.filter(_.getName.startsWith("worker_"))
- val workers = workerFiles.map(deserializeFromFile[WorkerInfo])
- (apps, drivers, workers)
+ override def read[T: ClassTag](prefix: String) = {
+ val files = new File(dir).listFiles().filter(_.getName.startsWith(prefix))
+ files.map(deserializeFromFile[T])
}
private def serializeIntoFile(file: File, value: AnyRef) {
val created = file.createNewFile()
if (!created) { throw new IllegalStateException("Could not create file: " + file) }
-
val serializer = serialization.findSerializerFor(value)
val serialized = serializer.toBinary(value)
-
val out = new FileOutputStream(file)
- out.write(serialized)
- out.close()
+ try {
+ out.write(serialized)
+ } finally {
+ out.close()
+ }
}
- def deserializeFromFile[T](file: File)(implicit m: Manifest[T]): T = {
+ private def deserializeFromFile[T](file: File)(implicit m: ClassTag[T]): T = {
val fileData = new Array[Byte](file.length().asInstanceOf[Int])
val dis = new DataInputStream(new FileInputStream(file))
- dis.readFully(fileData)
- dis.close()
-
+ try {
+ dis.readFully(fileData)
+ } finally {
+ dis.close()
+ }
val clazz = m.runtimeClass.asInstanceOf[Class[T]]
val serializer = serialization.serializerFor(clazz)
serializer.fromBinary(fileData).asInstanceOf[T]
}
+
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
index 4433a2ec29be6..cf77c86d760cf 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
@@ -17,30 +17,27 @@
package org.apache.spark.deploy.master
-import akka.actor.{Actor, ActorRef}
-
-import org.apache.spark.deploy.master.MasterMessages.ElectedLeader
+import org.apache.spark.annotation.DeveloperApi
/**
- * A LeaderElectionAgent keeps track of whether the current Master is the leader, meaning it
- * is the only Master serving requests.
- * In addition to the API provided, the LeaderElectionAgent will use of the following messages
- * to inform the Master of leader changes:
- * [[org.apache.spark.deploy.master.MasterMessages.ElectedLeader ElectedLeader]]
- * [[org.apache.spark.deploy.master.MasterMessages.RevokedLeadership RevokedLeadership]]
+ * :: DeveloperApi ::
+ *
+ * A LeaderElectionAgent tracks current master and is a common interface for all election Agents.
*/
-private[spark] trait LeaderElectionAgent extends Actor {
- // TODO: LeaderElectionAgent does not necessary to be an Actor anymore, need refactoring.
- val masterActor: ActorRef
+@DeveloperApi
+trait LeaderElectionAgent {
+ val masterActor: LeaderElectable
+ def stop() {} // to avoid noops in implementations.
}
-/** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */
-private[spark] class MonarchyLeaderAgent(val masterActor: ActorRef) extends LeaderElectionAgent {
- override def preStart() {
- masterActor ! ElectedLeader
- }
+@DeveloperApi
+trait LeaderElectable {
+ def electedLeader()
+ def revokedLeadership()
+}
- override def receive = {
- case _ =>
- }
+/** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */
+private[spark] class MonarchyLeaderAgent(val masterActor: LeaderElectable)
+ extends LeaderElectionAgent {
+ masterActor.electedLeader()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index f98b531316a3d..7b32c505def9b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -30,6 +30,7 @@ import scala.util.Random
import akka.actor._
import akka.pattern.ask
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
+import akka.serialization.Serialization
import akka.serialization.SerializationExtension
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
@@ -50,7 +51,7 @@ private[spark] class Master(
port: Int,
webUiPort: Int,
val securityMgr: SecurityManager)
- extends Actor with ActorLogReceive with Logging {
+ extends Actor with ActorLogReceive with Logging with LeaderElectable {
import context.dispatcher // to use Akka's scheduler.schedule()
@@ -61,7 +62,6 @@ private[spark] class Master(
val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200)
val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200)
val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15)
- val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "")
val RECOVERY_MODE = conf.get("spark.deploy.recoveryMode", "NONE")
val workers = new HashSet[WorkerInfo]
@@ -103,7 +103,7 @@ private[spark] class Master(
var persistenceEngine: PersistenceEngine = _
- var leaderElectionAgent: ActorRef = _
+ var leaderElectionAgent: LeaderElectionAgent = _
private var recoveryCompletionTask: Cancellable = _
@@ -130,23 +130,27 @@ private[spark] class Master(
masterMetricsSystem.start()
applicationMetricsSystem.start()
- persistenceEngine = RECOVERY_MODE match {
+ val (persistenceEngine_, leaderElectionAgent_) = RECOVERY_MODE match {
case "ZOOKEEPER" =>
logInfo("Persisting recovery state to ZooKeeper")
- new ZooKeeperPersistenceEngine(SerializationExtension(context.system), conf)
+ val zkFactory =
+ new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(context.system))
+ (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this))
case "FILESYSTEM" =>
- logInfo("Persisting recovery state to directory: " + RECOVERY_DIR)
- new FileSystemPersistenceEngine(RECOVERY_DIR, SerializationExtension(context.system))
+ val fsFactory =
+ new FileSystemRecoveryModeFactory(conf, SerializationExtension(context.system))
+ (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this))
+ case "CUSTOM" =>
+ val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory"))
+ val factory = clazz.getConstructor(conf.getClass, Serialization.getClass)
+ .newInstance(conf, SerializationExtension(context.system))
+ .asInstanceOf[StandaloneRecoveryModeFactory]
+ (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this))
case _ =>
- new BlackHolePersistenceEngine()
+ (new BlackHolePersistenceEngine(), new MonarchyLeaderAgent(this))
}
-
- leaderElectionAgent = RECOVERY_MODE match {
- case "ZOOKEEPER" =>
- context.actorOf(Props(classOf[ZooKeeperLeaderElectionAgent], self, masterUrl, conf))
- case _ =>
- context.actorOf(Props(classOf[MonarchyLeaderAgent], self))
- }
+ persistenceEngine = persistenceEngine_
+ leaderElectionAgent = leaderElectionAgent_
}
override def preRestart(reason: Throwable, message: Option[Any]) {
@@ -165,7 +169,15 @@ private[spark] class Master(
masterMetricsSystem.stop()
applicationMetricsSystem.stop()
persistenceEngine.close()
- context.stop(leaderElectionAgent)
+ leaderElectionAgent.stop()
+ }
+
+ override def electedLeader() {
+ self ! ElectedLeader
+ }
+
+ override def revokedLeadership() {
+ self ! RevokedLeadership
}
override def receiveWithLogging = {
@@ -341,7 +353,14 @@ private[spark] class Master(
case Some(workerInfo) =>
workerInfo.lastHeartbeat = System.currentTimeMillis()
case None =>
- logWarning("Got heartbeat from unregistered worker " + workerId)
+ if (workers.map(_.id).contains(workerId)) {
+ logWarning(s"Got heartbeat from unregistered worker $workerId." +
+ " Asking it to re-register.")
+ sender ! ReconnectWorker(masterUrl)
+ } else {
+ logWarning(s"Got heartbeat from unregistered worker $workerId." +
+ " This worker was never registered, so ignoring the heartbeat.")
+ }
}
}
@@ -714,8 +733,8 @@ private[spark] class Master(
try {
val replayBus = new ReplayListenerBus(eventLogPaths, fileSystem, compressionCodec)
- val ui = new SparkUI(new SparkConf, replayBus, appName + " (completed)",
- HistoryServer.UI_PATH_PREFIX + s"/${app.id}")
+ val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf),
+ appName + " (completed)", HistoryServer.UI_PATH_PREFIX + s"/${app.id}")
replayBus.replay()
appIdToUI(app.id) = ui
webUi.attachSparkUI(ui)
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
index 4b0dbbe543d3f..e34bee7854292 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
@@ -27,6 +27,7 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) {
var host = Utils.localHostName()
var port = 7077
var webUiPort = 8080
+ var propertiesFile: String = null
// Check for settings in environment variables
if (System.getenv("SPARK_MASTER_HOST") != null) {
@@ -38,12 +39,16 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) {
if (System.getenv("SPARK_MASTER_WEBUI_PORT") != null) {
webUiPort = System.getenv("SPARK_MASTER_WEBUI_PORT").toInt
}
+
+ parse(args.toList)
+
+ // This mutates the SparkConf, so all accesses to it must be made after this line
+ propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile)
+
if (conf.contains("spark.master.ui.port")) {
webUiPort = conf.get("spark.master.ui.port").toInt
}
- parse(args.toList)
-
def parse(args: List[String]): Unit = args match {
case ("--ip" | "-i") :: value :: tail =>
Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
@@ -63,7 +68,11 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) {
webUiPort = value
parse(tail)
- case ("--help" | "-h") :: tail =>
+ case ("--properties-file") :: value :: tail =>
+ propertiesFile = value
+ parse(tail)
+
+ case ("--help") :: tail =>
printUsageAndExit(0)
case Nil => {}
@@ -83,7 +92,9 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) {
" -i HOST, --ip HOST Hostname to listen on (deprecated, please use --host or -h) \n" +
" -h HOST, --host HOST Hostname to listen on\n" +
" -p PORT, --port PORT Port to listen on (default: 7077)\n" +
- " --webui-port PORT Port for web UI (default: 8080)")
+ " --webui-port PORT Port for web UI (default: 8080)\n" +
+ " --properties-file FILE Path to a custom Spark properties file.\n" +
+ " Default is conf/spark-defaults.conf.")
System.exit(exitCode)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
index e3640ea4f7e64..2e0e1e7036ac8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
@@ -17,6 +17,10 @@
package org.apache.spark.deploy.master
+import org.apache.spark.annotation.DeveloperApi
+
+import scala.reflect.ClassTag
+
/**
* Allows Master to persist any state that is necessary in order to recover from a failure.
* The following semantics are required:
@@ -25,36 +29,70 @@ package org.apache.spark.deploy.master
* Given these two requirements, we will have all apps and workers persisted, but
* we might not have yet deleted apps or workers that finished (so their liveness must be verified
* during recovery).
+ *
+ * The implementation of this trait defines how name-object pairs are stored or retrieved.
*/
-private[spark] trait PersistenceEngine {
- def addApplication(app: ApplicationInfo)
+@DeveloperApi
+trait PersistenceEngine {
- def removeApplication(app: ApplicationInfo)
+ /**
+ * Defines how the object is serialized and persisted. Implementation will
+ * depend on the store used.
+ */
+ def persist(name: String, obj: Object)
- def addWorker(worker: WorkerInfo)
+ /**
+ * Defines how the object referred by its name is removed from the store.
+ */
+ def unpersist(name: String)
- def removeWorker(worker: WorkerInfo)
+ /**
+ * Gives all objects, matching a prefix. This defines how objects are
+ * read/deserialized back.
+ */
+ def read[T: ClassTag](prefix: String): Seq[T]
- def addDriver(driver: DriverInfo)
+ final def addApplication(app: ApplicationInfo): Unit = {
+ persist("app_" + app.id, app)
+ }
- def removeDriver(driver: DriverInfo)
+ final def removeApplication(app: ApplicationInfo): Unit = {
+ unpersist("app_" + app.id)
+ }
+
+ final def addWorker(worker: WorkerInfo): Unit = {
+ persist("worker_" + worker.id, worker)
+ }
+
+ final def removeWorker(worker: WorkerInfo): Unit = {
+ unpersist("worker_" + worker.id)
+ }
+
+ final def addDriver(driver: DriverInfo): Unit = {
+ persist("driver_" + driver.id, driver)
+ }
+
+ final def removeDriver(driver: DriverInfo): Unit = {
+ unpersist("driver_" + driver.id)
+ }
/**
* Returns the persisted data sorted by their respective ids (which implies that they're
* sorted by time of creation).
*/
- def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo])
+ final def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
+ (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_"))
+ }
def close() {}
}
private[spark] class BlackHolePersistenceEngine extends PersistenceEngine {
- override def addApplication(app: ApplicationInfo) {}
- override def removeApplication(app: ApplicationInfo) {}
- override def addWorker(worker: WorkerInfo) {}
- override def removeWorker(worker: WorkerInfo) {}
- override def addDriver(driver: DriverInfo) {}
- override def removeDriver(driver: DriverInfo) {}
-
- override def readPersistedData() = (Nil, Nil, Nil)
+
+ override def persist(name: String, obj: Object): Unit = {}
+
+ override def unpersist(name: String): Unit = {}
+
+ override def read[T: ClassTag](name: String): Seq[T] = Nil
+
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
new file mode 100644
index 0000000000000..1096eb0368357
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import akka.serialization.Serialization
+
+import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * ::DeveloperApi::
+ *
+ * Implementation of this class can be plugged in as recovery mode alternative for Spark's
+ * Standalone mode.
+ *
+ */
+@DeveloperApi
+abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serialization) {
+
+ /**
+ * PersistenceEngine defines how the persistent data(Information about worker, driver etc..)
+ * is handled for recovery.
+ *
+ */
+ def createPersistenceEngine(): PersistenceEngine
+
+ /**
+ * Create an instance of LeaderAgent that decides who gets elected as master.
+ */
+ def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent
+}
+
+/**
+ * LeaderAgent in this case is a no-op. Since leader is forever leader as the actual
+ * recovery is made by restoring from filesystem.
+ */
+private[spark] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serialization)
+ extends StandaloneRecoveryModeFactory(conf, serializer) with Logging {
+ val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "")
+
+ def createPersistenceEngine() = {
+ logInfo("Persisting recovery state to directory: " + RECOVERY_DIR)
+ new FileSystemPersistenceEngine(RECOVERY_DIR, serializer)
+ }
+
+ def createLeaderElectionAgent(master: LeaderElectable) = new MonarchyLeaderAgent(master)
+}
+
+private[spark] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serialization)
+ extends StandaloneRecoveryModeFactory(conf, serializer) {
+ def createPersistenceEngine() = new ZooKeeperPersistenceEngine(conf, serializer)
+
+ def createLeaderElectionAgent(master: LeaderElectable) =
+ new ZooKeeperLeaderElectionAgent(master, conf)
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
index c5fa9cf7d7c2d..473ddc23ff0f3 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable
import akka.actor.ActorRef
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils
private[spark] class WorkerInfo(
@@ -50,7 +51,7 @@ private[spark] class WorkerInfo(
def coresFree: Int = cores - coresUsed
def memoryFree: Int = memory - memoryUsed
- private def readObject(in: java.io.ObjectInputStream) : Unit = {
+ private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
in.defaultReadObject()
init()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
index 285f9b014e291..8eaa0ad948519 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
@@ -24,9 +24,8 @@ import org.apache.spark.deploy.master.MasterMessages._
import org.apache.curator.framework.CuratorFramework
import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch}
-private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef,
- masterUrl: String, conf: SparkConf)
- extends LeaderElectionAgent with LeaderLatchListener with Logging {
+private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElectable,
+ conf: SparkConf) extends LeaderLatchListener with LeaderElectionAgent with Logging {
val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/leader_election"
@@ -34,30 +33,21 @@ private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef,
private var leaderLatch: LeaderLatch = _
private var status = LeadershipStatus.NOT_LEADER
- override def preStart() {
+ start()
+ def start() {
logInfo("Starting ZooKeeper LeaderElection agent")
zk = SparkCuratorUtil.newClient(conf)
leaderLatch = new LeaderLatch(zk, WORKING_DIR)
leaderLatch.addListener(this)
-
leaderLatch.start()
}
- override def preRestart(reason: scala.Throwable, message: scala.Option[scala.Any]) {
- logError("LeaderElectionAgent failed...", reason)
- super.preRestart(reason, message)
- }
-
- override def postStop() {
+ override def stop() {
leaderLatch.close()
zk.close()
}
- override def receive = {
- case _ =>
- }
-
override def isLeader() {
synchronized {
// could have lost leadership by now.
@@ -85,10 +75,10 @@ private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef,
def updateLeadershipStatus(isLeader: Boolean) {
if (isLeader && status == LeadershipStatus.NOT_LEADER) {
status = LeadershipStatus.LEADER
- masterActor ! ElectedLeader
+ masterActor.electedLeader()
} else if (!isLeader && status == LeadershipStatus.LEADER) {
status = LeadershipStatus.NOT_LEADER
- masterActor ! RevokedLeadership
+ masterActor.revokedLeadership()
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
index 834dfedee52ce..e11ac031fb9c6 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -17,15 +17,18 @@
package org.apache.spark.deploy.master
+import akka.serialization.Serialization
+
import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
-import akka.serialization.Serialization
import org.apache.curator.framework.CuratorFramework
import org.apache.zookeeper.CreateMode
import org.apache.spark.{Logging, SparkConf}
-class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf)
+
+private[spark] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization)
extends PersistenceEngine
with Logging
{
@@ -34,52 +37,31 @@ class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf)
SparkCuratorUtil.mkdir(zk, WORKING_DIR)
- override def addApplication(app: ApplicationInfo) {
- serializeIntoFile(WORKING_DIR + "/app_" + app.id, app)
- }
- override def removeApplication(app: ApplicationInfo) {
- zk.delete().forPath(WORKING_DIR + "/app_" + app.id)
+ override def persist(name: String, obj: Object): Unit = {
+ serializeIntoFile(WORKING_DIR + "/" + name, obj)
}
- override def addDriver(driver: DriverInfo) {
- serializeIntoFile(WORKING_DIR + "/driver_" + driver.id, driver)
+ override def unpersist(name: String): Unit = {
+ zk.delete().forPath(WORKING_DIR + "/" + name)
}
- override def removeDriver(driver: DriverInfo) {
- zk.delete().forPath(WORKING_DIR + "/driver_" + driver.id)
- }
-
- override def addWorker(worker: WorkerInfo) {
- serializeIntoFile(WORKING_DIR + "/worker_" + worker.id, worker)
- }
-
- override def removeWorker(worker: WorkerInfo) {
- zk.delete().forPath(WORKING_DIR + "/worker_" + worker.id)
+ override def read[T: ClassTag](prefix: String) = {
+ val file = zk.getChildren.forPath(WORKING_DIR).filter(_.startsWith(prefix))
+ file.map(deserializeFromFile[T]).flatten
}
override def close() {
zk.close()
}
- override def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
- val sortedFiles = zk.getChildren().forPath(WORKING_DIR).toList.sorted
- val appFiles = sortedFiles.filter(_.startsWith("app_"))
- val apps = appFiles.map(deserializeFromFile[ApplicationInfo]).flatten
- val driverFiles = sortedFiles.filter(_.startsWith("driver_"))
- val drivers = driverFiles.map(deserializeFromFile[DriverInfo]).flatten
- val workerFiles = sortedFiles.filter(_.startsWith("worker_"))
- val workers = workerFiles.map(deserializeFromFile[WorkerInfo]).flatten
- (apps, drivers, workers)
- }
-
private def serializeIntoFile(path: String, value: AnyRef) {
val serializer = serialization.findSerializerFor(value)
val serialized = serializer.toBinary(value)
zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized)
}
- def deserializeFromFile[T](filename: String)(implicit m: Manifest[T]): Option[T] = {
+ def deserializeFromFile[T](filename: String)(implicit m: ClassTag[T]): Option[T] = {
val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename)
val clazz = m.runtimeClass.asInstanceOf[Class[T]]
val serializer = serialization.serializerFor(clazz)
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
index 2e9be2a180c68..28e9662db5da9 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
@@ -20,6 +20,8 @@ package org.apache.spark.deploy.worker
import java.io.{File, FileOutputStream, InputStream, IOException}
import java.lang.System._
+import scala.collection.Map
+
import org.apache.spark.Logging
import org.apache.spark.deploy.Command
import org.apache.spark.util.Utils
@@ -29,7 +31,29 @@ import org.apache.spark.util.Utils
*/
private[spark]
object CommandUtils extends Logging {
- def buildCommandSeq(command: Command, memory: Int, sparkHome: String): Seq[String] = {
+
+ /**
+ * Build a ProcessBuilder based on the given parameters.
+ * The `env` argument is exposed for testing.
+ */
+ def buildProcessBuilder(
+ command: Command,
+ memory: Int,
+ sparkHome: String,
+ substituteArguments: String => String,
+ classPaths: Seq[String] = Seq[String](),
+ env: Map[String, String] = sys.env): ProcessBuilder = {
+ val localCommand = buildLocalCommand(command, substituteArguments, classPaths, env)
+ val commandSeq = buildCommandSeq(localCommand, memory, sparkHome)
+ val builder = new ProcessBuilder(commandSeq: _*)
+ val environment = builder.environment()
+ for ((key, value) <- localCommand.environment) {
+ environment.put(key, value)
+ }
+ builder
+ }
+
+ private def buildCommandSeq(command: Command, memory: Int, sparkHome: String): Seq[String] = {
val runner = sys.env.get("JAVA_HOME").map(_ + "/bin/java").getOrElse("java")
// SPARK-698: do not call the run.cmd script, as process.destroy()
@@ -38,11 +62,41 @@ object CommandUtils extends Logging {
command.arguments
}
+ /**
+ * Build a command based on the given one, taking into account the local environment
+ * of where this command is expected to run, substitute any placeholders, and append
+ * any extra class paths.
+ */
+ private def buildLocalCommand(
+ command: Command,
+ substituteArguments: String => String,
+ classPath: Seq[String] = Seq[String](),
+ env: Map[String, String]): Command = {
+ val libraryPathName = Utils.libraryPathEnvName
+ val libraryPathEntries = command.libraryPathEntries
+ val cmdLibraryPath = command.environment.get(libraryPathName)
+
+ val newEnvironment = if (libraryPathEntries.nonEmpty && libraryPathName.nonEmpty) {
+ val libraryPaths = libraryPathEntries ++ cmdLibraryPath ++ env.get(libraryPathName)
+ command.environment + ((libraryPathName, libraryPaths.mkString(File.pathSeparator)))
+ } else {
+ command.environment
+ }
+
+ Command(
+ command.mainClass,
+ command.arguments.map(substituteArguments),
+ newEnvironment,
+ command.classPathEntries ++ classPath,
+ Seq[String](), // library path already captured in environment variable
+ command.javaOpts)
+ }
+
/**
* Attention: this must always be aligned with the environment variables in the run scripts and
* the way the JAVA_OPTS are assembled there.
*/
- def buildJavaOpts(command: Command, memory: Int, sparkHome: String): Seq[String] = {
+ private def buildJavaOpts(command: Command, memory: Int, sparkHome: String): Seq[String] = {
val memoryOpts = Seq(s"-Xms${memory}M", s"-Xmx${memory}M")
// Exists for backwards compatibility with older Spark versions
@@ -53,14 +107,6 @@ object CommandUtils extends Logging {
logWarning("Set SPARK_LOCAL_DIRS for node-specific storage locations.")
}
- val libraryOpts =
- if (command.libraryPathEntries.size > 0) {
- val joined = command.libraryPathEntries.mkString(File.pathSeparator)
- Seq(s"-Djava.library.path=$joined")
- } else {
- Seq()
- }
-
// Figure out our classpath with the external compute-classpath script
val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh"
val classPath = Utils.executeAndGetOutput(
@@ -71,7 +117,7 @@ object CommandUtils extends Logging {
val javaVersion = System.getProperty("java.version")
val permGenOpt = if (!javaVersion.startsWith("1.8")) Some("-XX:MaxPermSize=128m") else None
Seq("-cp", userClassPath.filterNot(_.isEmpty).mkString(File.pathSeparator)) ++
- permGenOpt ++ libraryOpts ++ workerLocalOpts ++ command.javaOpts ++ memoryOpts
+ permGenOpt ++ workerLocalOpts ++ command.javaOpts ++ memoryOpts
}
/** Spawn a thread that will redirect a given stream to a file */
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
index 9f9911762505a..28cab36c7b9e2 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConversions._
import scala.collection.Map
import akka.actor.ActorRef
-import com.google.common.base.Charsets
+import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileUtil, Path}
@@ -76,17 +76,9 @@ private[spark] class DriverRunner(
// Make sure user application jar is on the classpath
// TODO: If we add ability to submit multiple jars they should also be added here
- val classPath = driverDesc.command.classPathEntries ++ Seq(s"$localJarFilename")
- val newCommand = Command(
- driverDesc.command.mainClass,
- driverDesc.command.arguments.map(substituteVariables),
- driverDesc.command.environment,
- classPath,
- driverDesc.command.libraryPathEntries,
- driverDesc.command.javaOpts)
- val command = CommandUtils.buildCommandSeq(newCommand, driverDesc.mem,
- sparkHome.getAbsolutePath)
- launchDriver(command, driverDesc.command.environment, driverDir, driverDesc.supervise)
+ val builder = CommandUtils.buildProcessBuilder(driverDesc.command, driverDesc.mem,
+ sparkHome.getAbsolutePath, substituteVariables, Seq(localJarFilename))
+ launchDriver(builder, driverDir, driverDesc.supervise)
}
catch {
case e: Exception => finalException = Some(e)
@@ -165,11 +157,8 @@ private[spark] class DriverRunner(
localJarFilename
}
- private def launchDriver(command: Seq[String], envVars: Map[String, String], baseDir: File,
- supervise: Boolean) {
- val builder = new ProcessBuilder(command: _*).directory(baseDir)
- envVars.map{ case(k,v) => builder.environment().put(k, v) }
-
+ private def launchDriver(builder: ProcessBuilder, baseDir: File, supervise: Boolean) {
+ builder.directory(baseDir)
def initialize(process: Process) = {
// Redirect stdout and stderr to files
val stdout = new File(baseDir, "stdout")
@@ -177,8 +166,8 @@ private[spark] class DriverRunner(
val stderr = new File(baseDir, "stderr")
val header = "Launch Command: %s\n%s\n\n".format(
- command.mkString("\"", "\" \"", "\""), "=" * 40)
- Files.append(header, stderr, Charsets.UTF_8)
+ builder.command.mkString("\"", "\" \"", "\""), "=" * 40)
+ Files.append(header, stderr, UTF_8)
CommandUtils.redirectStream(process.getErrorStream, stderr)
}
runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise)
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index 71650cd773bcf..8ba6a01bbcb97 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -19,8 +19,10 @@ package org.apache.spark.deploy.worker
import java.io._
+import scala.collection.JavaConversions._
+
import akka.actor.ActorRef
-import com.google.common.base.Charsets
+import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files
import org.apache.spark.{SparkConf, Logging}
@@ -111,36 +113,25 @@ private[spark] class ExecutorRunner(
case "{{EXECUTOR_ID}}" => execId.toString
case "{{HOSTNAME}}" => host
case "{{CORES}}" => cores.toString
+ case "{{APP_ID}}" => appId
case other => other
}
- def getCommandSeq = {
- val command = Command(
- appDesc.command.mainClass,
- appDesc.command.arguments.map(substituteVariables) ++ Seq(appId),
- appDesc.command.environment,
- appDesc.command.classPathEntries,
- appDesc.command.libraryPathEntries,
- appDesc.command.javaOpts)
- CommandUtils.buildCommandSeq(command, memory, sparkHome.getAbsolutePath)
- }
-
/**
* Download and run the executor described in our ApplicationDescription
*/
def fetchAndRunExecutor() {
try {
// Launch the process
- val command = getCommandSeq
+ val builder = CommandUtils.buildProcessBuilder(appDesc.command, memory,
+ sparkHome.getAbsolutePath, substituteVariables)
+ val command = builder.command()
logInfo("Launch command: " + command.mkString("\"", "\" \"", "\""))
- val builder = new ProcessBuilder(command: _*).directory(executorDir)
- val env = builder.environment()
- for ((key, value) <- appDesc.command.environment) {
- env.put(key, value)
- }
+
+ builder.directory(executorDir)
// In case we are running this from within the Spark Shell, avoid creating a "scala"
// parent process for the executor command
- env.put("SPARK_LAUNCH_WITH_SCALA", "0")
+ builder.environment.put("SPARK_LAUNCH_WITH_SCALA", "0")
process = builder.start()
val header = "Spark Executor Command: %s\n%s\n\n".format(
command.mkString("\"", "\" \"", "\""), "=" * 40)
@@ -150,7 +141,7 @@ private[spark] class ExecutorRunner(
stdoutAppender = FileAppender(process.getInputStream, stdout, conf)
val stderr = new File(executorDir, "stderr")
- Files.write(header, stderr, Charsets.UTF_8)
+ Files.write(header, stderr, UTF_8)
stderrAppender = FileAppender(process.getErrorStream, stderr, conf)
state = ExecutorState.RUNNING
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala
new file mode 100644
index 0000000000000..b9798963bab0a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.worker
+
+import org.apache.spark.{Logging, SparkConf, SecurityManager}
+import org.apache.spark.network.TransportContext
+import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.network.sasl.SaslRpcHandler
+import org.apache.spark.network.server.TransportServer
+import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
+
+/**
+ * Provides a server from which Executors can read shuffle files (rather than reading directly from
+ * each other), to provide uninterrupted access to the files in the face of executors being turned
+ * off or killed.
+ *
+ * Optionally requires SASL authentication in order to read. See [[SecurityManager]].
+ */
+private[worker]
+class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: SecurityManager)
+ extends Logging {
+
+ private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false)
+ private val port = sparkConf.getInt("spark.shuffle.service.port", 7337)
+ private val useSasl: Boolean = securityManager.isAuthenticationEnabled()
+
+ private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0)
+ private val blockHandler = new ExternalShuffleBlockHandler(transportConf)
+ private val transportContext: TransportContext = {
+ val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler
+ new TransportContext(transportConf, handler)
+ }
+
+ private var server: TransportServer = _
+
+ /** Starts the external shuffle service if the user has configured us to. */
+ def startIfEnabled() {
+ if (enabled) {
+ require(server == null, "Shuffle server already started")
+ logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl")
+ server = transportContext.createServer(port)
+ }
+ }
+
+ def stop() {
+ if (enabled && server != null) {
+ server.close()
+ server = null
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 3b13f43a1868c..eb11163538b20 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -20,16 +20,16 @@ package org.apache.spark.deploy.worker
import java.io.File
import java.io.IOException
import java.text.SimpleDateFormat
-import java.util.Date
+import java.util.{UUID, Date}
import scala.collection.JavaConversions._
import scala.collection.mutable.HashMap
import scala.concurrent.duration._
import scala.language.postfixOps
+import scala.util.Random
import akka.actor._
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
-import org.apache.commons.io.FileUtils
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
@@ -65,8 +65,22 @@ private[spark] class Worker(
// Send a heartbeat every (heartbeat timeout) / 4 milliseconds
val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4
- val REGISTRATION_TIMEOUT = 20.seconds
- val REGISTRATION_RETRIES = 3
+ // Model retries to connect to the master, after Hadoop's model.
+ // The first six attempts to reconnect are in shorter intervals (between 5 and 15 seconds)
+ // Afterwards, the next 10 attempts are between 30 and 90 seconds.
+ // A bit of randomness is introduced so that not all of the workers attempt to reconnect at
+ // the same time.
+ val INITIAL_REGISTRATION_RETRIES = 6
+ val TOTAL_REGISTRATION_RETRIES = INITIAL_REGISTRATION_RETRIES + 10
+ val FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND = 0.500
+ val REGISTRATION_RETRY_FUZZ_MULTIPLIER = {
+ val randomNumberGenerator = new Random(UUID.randomUUID.getMostSignificantBits)
+ randomNumberGenerator.nextDouble + FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND
+ }
+ val INITIAL_REGISTRATION_RETRY_INTERVAL = (math.round(10 *
+ REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds
+ val PROLONGED_REGISTRATION_RETRY_INTERVAL = (math.round(60
+ * REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds
val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", false)
// How often worker will clean up old app folders
@@ -96,6 +110,9 @@ private[spark] class Worker(
val drivers = new HashMap[String, DriverRunner]
val finishedDrivers = new HashMap[String, DriverRunner]
+ // The shuffle service is not actually started unless configured.
+ val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr)
+
val publicAddress = {
val envVar = System.getenv("SPARK_PUBLIC_DNS")
if (envVar != null) envVar else host
@@ -104,6 +121,7 @@ private[spark] class Worker(
var coresUsed = 0
var memoryUsed = 0
+ var connectionAttemptCount = 0
val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr)
val workerSource = new WorkerSource(this)
@@ -138,6 +156,7 @@ private[spark] class Worker(
logInfo("Spark home: " + sparkHome)
createWorkDir()
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+ shuffleService.startIfEnabled()
webUi = new WorkerWebUI(this, workDir, webUiPort)
webUi.bind()
registerWithMaster()
@@ -157,9 +176,12 @@ private[spark] class Worker(
throw new SparkException("Invalid spark URL: " + x)
}
connected = true
+ // Cancel any outstanding re-registration attempts because we found a new master
+ registrationRetryTimer.foreach(_.cancel())
+ registrationRetryTimer = None
}
- def tryRegisterAllMasters() {
+ private def tryRegisterAllMasters() {
for (masterUrl <- masterUrls) {
logInfo("Connecting to master " + masterUrl + "...")
val actor = context.actorSelection(Master.toAkkaUrl(masterUrl))
@@ -167,26 +189,80 @@ private[spark] class Worker(
}
}
- def registerWithMaster() {
- tryRegisterAllMasters()
- var retries = 0
- registrationRetryTimer = Some {
- context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) {
- Utils.tryOrExit {
- retries += 1
- if (registered) {
- registrationRetryTimer.foreach(_.cancel())
- } else if (retries >= REGISTRATION_RETRIES) {
- logError("All masters are unresponsive! Giving up.")
- System.exit(1)
- } else {
- tryRegisterAllMasters()
+ /**
+ * Re-register with the master because a network failure or a master failure has occurred.
+ * If the re-registration attempt threshold is exceeded, the worker exits with error.
+ * Note that for thread-safety this should only be called from the actor.
+ */
+ private def reregisterWithMaster(): Unit = {
+ Utils.tryOrExit {
+ connectionAttemptCount += 1
+ if (registered) {
+ registrationRetryTimer.foreach(_.cancel())
+ registrationRetryTimer = None
+ } else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) {
+ logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)")
+ /**
+ * Re-register with the active master this worker has been communicating with. If there
+ * is none, then it means this worker is still bootstrapping and hasn't established a
+ * connection with a master yet, in which case we should re-register with all masters.
+ *
+ * It is important to re-register only with the active master during failures. Otherwise,
+ * if the worker unconditionally attempts to re-register with all masters, the following
+ * race condition may arise and cause a "duplicate worker" error detailed in SPARK-4592:
+ *
+ * (1) Master A fails and Worker attempts to reconnect to all masters
+ * (2) Master B takes over and notifies Worker
+ * (3) Worker responds by registering with Master B
+ * (4) Meanwhile, Worker's previous reconnection attempt reaches Master B,
+ * causing the same Worker to register with Master B twice
+ *
+ * Instead, if we only register with the known active master, we can assume that the
+ * old master must have died because another master has taken over. Note that this is
+ * still not safe if the old master recovers within this interval, but this is a much
+ * less likely scenario.
+ */
+ if (master != null) {
+ master ! RegisterWorker(
+ workerId, host, port, cores, memory, webUi.boundPort, publicAddress)
+ } else {
+ // We are retrying the initial registration
+ tryRegisterAllMasters()
+ }
+ // We have exceeded the initial registration retry threshold
+ // All retries from now on should use a higher interval
+ if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) {
+ registrationRetryTimer.foreach(_.cancel())
+ registrationRetryTimer = Some {
+ context.system.scheduler.schedule(PROLONGED_REGISTRATION_RETRY_INTERVAL,
+ PROLONGED_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster)
}
}
+ } else {
+ logError("All masters are unresponsive! Giving up.")
+ System.exit(1)
}
}
}
+ def registerWithMaster() {
+ // DisassociatedEvent may be triggered multiple times, so don't attempt registration
+ // if there are outstanding registration attempts scheduled.
+ registrationRetryTimer match {
+ case None =>
+ registered = false
+ tryRegisterAllMasters()
+ connectionAttemptCount = 0
+ registrationRetryTimer = Some {
+ context.system.scheduler.schedule(INITIAL_REGISTRATION_RETRY_INTERVAL,
+ INITIAL_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster)
+ }
+ case Some(_) =>
+ logInfo("Not spawning another attempt to register with the master, since there is an" +
+ " attempt scheduled already.")
+ }
+ }
+
override def receiveWithLogging = {
case RegisteredWorker(masterUrl, masterWebUiUrl) =>
logInfo("Successfully registered with master " + masterUrl)
@@ -244,6 +320,10 @@ private[spark] class Worker(
System.exit(1)
}
+ case ReconnectWorker(masterUrl) =>
+ logInfo(s"Master with url $masterUrl requested this worker to reconnect.")
+ registerWithMaster()
+
case LaunchExecutor(masterUrl, appId, execId, appDesc, cores_, memory_) =>
if (masterUrl != activeMasterUrl) {
logWarning("Invalid Master (" + masterUrl + ") attempted to launch executor.")
@@ -355,17 +435,21 @@ private[spark] class Worker(
logInfo(s"$x Disassociated !")
masterDisconnected()
- case RequestWorkerState => {
+ case RequestWorkerState =>
sender ! WorkerStateResponse(host, port, workerId, executors.values.toList,
finishedExecutors.values.toList, drivers.values.toList,
finishedDrivers.values.toList, activeMasterUrl, cores, memory,
coresUsed, memoryUsed, activeMasterWebUiUrl)
- }
+
+ case ReregisterWithMaster =>
+ reregisterWithMaster()
+
}
- def masterDisconnected() {
+ private def masterDisconnected() {
logError("Connection to master failed! Waiting for master to reconnect...")
connected = false
+ registerWithMaster()
}
def generateWorkerId(): String = {
@@ -377,6 +461,7 @@ private[spark] class Worker(
registrationRetryTimer.foreach(_.cancel())
executors.values.foreach(_.kill())
drivers.values.foreach(_.kill())
+ shuffleService.stop()
webUi.stop()
metricsSystem.stop()
}
@@ -399,7 +484,8 @@ private[spark] object Worker extends Logging {
cores: Int,
memory: Int,
masterUrls: Array[String],
- workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = {
+ workDir: String,
+ workerNumber: Option[Int] = None): (ActorSystem, Int) = {
// The LocalSparkCluster runs multiple local sparkWorkerX actor systems
val conf = new SparkConf
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
index 1e295aaa48c30..019cd70f2a229 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
@@ -33,6 +33,7 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) {
var memory = inferDefaultMemory()
var masters: Array[String] = null
var workDir: String = null
+ var propertiesFile: String = null
// Check for settings in environment variables
if (System.getenv("SPARK_WORKER_PORT") != null) {
@@ -41,21 +42,27 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) {
if (System.getenv("SPARK_WORKER_CORES") != null) {
cores = System.getenv("SPARK_WORKER_CORES").toInt
}
- if (System.getenv("SPARK_WORKER_MEMORY") != null) {
- memory = Utils.memoryStringToMb(System.getenv("SPARK_WORKER_MEMORY"))
+ if (conf.getenv("SPARK_WORKER_MEMORY") != null) {
+ memory = Utils.memoryStringToMb(conf.getenv("SPARK_WORKER_MEMORY"))
}
if (System.getenv("SPARK_WORKER_WEBUI_PORT") != null) {
webUiPort = System.getenv("SPARK_WORKER_WEBUI_PORT").toInt
}
- if (conf.contains("spark.worker.ui.port")) {
- webUiPort = conf.get("spark.worker.ui.port").toInt
- }
if (System.getenv("SPARK_WORKER_DIR") != null) {
workDir = System.getenv("SPARK_WORKER_DIR")
}
parse(args.toList)
+ // This mutates the SparkConf, so all accesses to it must be made after this line
+ propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile)
+
+ if (conf.contains("spark.worker.ui.port")) {
+ webUiPort = conf.get("spark.worker.ui.port").toInt
+ }
+
+ checkWorkerMemory()
+
def parse(args: List[String]): Unit = args match {
case ("--ip" | "-i") :: value :: tail =>
Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
@@ -87,7 +94,11 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) {
webUiPort = value
parse(tail)
- case ("--help" | "-h") :: tail =>
+ case ("--properties-file") :: value :: tail =>
+ propertiesFile = value
+ parse(tail)
+
+ case ("--help") :: tail =>
printUsageAndExit(0)
case value :: tail =>
@@ -122,7 +133,9 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) {
" -i HOST, --ip IP Hostname to listen on (deprecated, please use --host or -h)\n" +
" -h HOST, --host HOST Hostname to listen on\n" +
" -p PORT, --port PORT Port to listen on (default: random)\n" +
- " --webui-port PORT Port for web UI (default: 8081)")
+ " --webui-port PORT Port for web UI (default: 8081)\n" +
+ " --properties-file FILE Path to a custom Spark properties file.\n" +
+ " Default is conf/spark-defaults.conf.")
System.exit(exitCode)
}
@@ -153,4 +166,11 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) {
// Leave out 1 GB for the operating system, but don't return a negative memory size
math.max(totalMb - 1024, 512)
}
+
+ def checkWorkerMemory(): Unit = {
+ if (memory <= 0) {
+ val message = "Memory can't be 0, missing a M or G on the end of the memory specification?"
+ throw new IllegalStateException(message)
+ }
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
index 6d0d0bbe5ecec..63a8ac817b618 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
@@ -54,7 +54,7 @@ private[spark] class WorkerWatcher(workerUrl: String)
case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) =>
logInfo(s"Successfully connected to $workerUrl")
- case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound)
+ case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _)
if isWorker(remoteAddress) =>
// These logs may not be seen if the worker (and associated pipe) has died
logError(s"Could not initialize connection to worker $workerUrl. Exiting.")
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 06061edfc0844..5f46f3b1f085e 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -21,7 +21,7 @@ import java.nio.ByteBuffer
import scala.concurrent.Await
-import akka.actor.{Actor, ActorSelection, Props}
+import akka.actor.{Actor, ActorSelection, ActorSystem, Props}
import akka.pattern.Patterns
import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent}
@@ -38,7 +38,8 @@ private[spark] class CoarseGrainedExecutorBackend(
executorId: String,
hostPort: String,
cores: Int,
- sparkProperties: Seq[(String, String)])
+ sparkProperties: Seq[(String, String)],
+ actorSystem: ActorSystem)
extends Actor with ActorLogReceive with ExecutorBackend with Logging {
Utils.checkHostPort(hostPort, "Expected hostport")
@@ -56,9 +57,9 @@ private[spark] class CoarseGrainedExecutorBackend(
override def receiveWithLogging = {
case RegisteredExecutor =>
logInfo("Successfully registered with driver")
- // Make this host instead of hostPort ?
- executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties,
- false)
+ val (hostname, _) = Utils.parseHostPort(hostPort)
+ executor = new Executor(executorId, hostname, sparkProperties, cores, isLocal = false,
+ actorSystem)
case RegisterExecutorFailed(message) =>
logError("Slave registration failed: " + message)
@@ -130,12 +131,13 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
// Create a new ActorSystem using driver's Spark properties to run the backend.
val driverConf = new SparkConf().setAll(props)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
- "sparkExecutor", hostname, port, driverConf, new SecurityManager(driverConf))
+ SparkEnv.executorActorSystemName,
+ hostname, port, driverConf, new SecurityManager(driverConf))
// set it
val sparkHostPort = hostname + ":" + boundPort
actorSystem.actorOf(
Props(classOf[CoarseGrainedExecutorBackend],
- driverUrl, executorId, sparkHostPort, cores, props),
+ driverUrl, executorId, sparkHostPort, cores, props, actorSystem),
name = "Executor")
workerUrl.foreach { url =>
actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher")
@@ -152,6 +154,9 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
"Usage: CoarseGrainedExecutorBackend " +
" [] ")
System.exit(1)
+
+ // NB: These arguments are provided by SparkDeploySchedulerBackend (for standalone mode)
+ // and CoarseMesosSchedulerBackend (for mesos mode).
case 5 =>
run(args(0), args(1), args(2), args(3).toInt, args(4), None)
case x if x > 5 =>
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 9bbfcdc4a0b6e..835157fc520aa 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -26,21 +26,26 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.control.NonFatal
+import akka.actor.{Props, ActorSystem}
+
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler._
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
-import org.apache.spark.util.{AkkaUtils, Utils}
+import org.apache.spark.util.{SparkUncaughtExceptionHandler, AkkaUtils, Utils}
/**
* Spark executor used with Mesos, YARN, and the standalone scheduler.
+ * In coarse-grained mode, an existing actor system is provided.
*/
private[spark] class Executor(
executorId: String,
slaveHostname: String,
properties: Seq[(String, String)],
- isLocal: Boolean = false)
+ numCores: Int,
+ isLocal: Boolean = false,
+ actorSystem: ActorSystem = null)
extends Logging
{
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
@@ -68,25 +73,31 @@ private[spark] class Executor(
// Setup an uncaught exception handler for non-local mode.
// Make any thread terminations due to uncaught exceptions kill the entire
// executor process to avoid surprising stalls.
- Thread.setDefaultUncaughtExceptionHandler(ExecutorUncaughtExceptionHandler)
+ Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler)
}
val executorSource = new ExecutorSource(this, executorId)
// Initialize Spark environment (using system properties read above)
- conf.set("spark.executor.id", "executor." + executorId)
+ conf.set("spark.executor.id", executorId)
private val env = {
if (!isLocal) {
- val _env = SparkEnv.create(conf, executorId, slaveHostname, 0,
- isDriver = false, isLocal = false)
+ val port = conf.getInt("spark.executor.port", 0)
+ val _env = SparkEnv.createExecutorEnv(
+ conf, executorId, slaveHostname, port, numCores, isLocal, actorSystem)
SparkEnv.set(_env)
_env.metricsSystem.registerSource(executorSource)
+ _env.blockManager.initialize(conf.getAppId)
_env
} else {
SparkEnv.get
}
}
+ // Create an actor for receiving RPCs from the driver
+ private val executorActor = env.actorSystem.actorOf(
+ Props(new ExecutorActor(executorId)), "ExecutorActor")
+
// Create our ClassLoader
// do this after SparkEnv creation so can access the SecurityManager
private val urlClassLoader = createClassLoader()
@@ -99,6 +110,9 @@ private[spark] class Executor(
// to send the result back.
private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
+ // Limit of bytes for total size of results (default is 1GB)
+ private val maxResultSize = Utils.getMaxResultSize(conf)
+
// Start worker thread pool
val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
@@ -123,6 +137,7 @@ private[spark] class Executor(
def stop() {
env.metricsSystem.report()
+ env.actorSystem.stop(executorActor)
isStopped = true
threadPool.shutdown()
if (!isLocal) {
@@ -147,8 +162,7 @@ private[spark] class Executor(
}
override def run() {
- val startTime = System.currentTimeMillis()
- SparkEnv.set(env)
+ val deserializeStartTime = System.currentTimeMillis()
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo(s"Running $taskName (TID $taskId)")
@@ -158,7 +172,6 @@ private[spark] class Executor(
val startGCTime = gcTime
try {
- SparkEnv.set(env)
Accumulators.clear()
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
updateDependencies(taskFiles, taskJars)
@@ -194,7 +207,7 @@ private[spark] class Executor(
val afterSerialization = System.currentTimeMillis()
for (m <- task.metrics) {
- m.executorDeserializeTime = taskStart - startTime
+ m.executorDeserializeTime = taskStart - deserializeStartTime
m.executorRunTime = taskFinish - taskStart
m.jvmGCTime = gcTime - startGCTime
m.resultSerializationTime = afterSerialization - beforeSerialization
@@ -207,25 +220,27 @@ private[spark] class Executor(
val resultSize = serializedDirectResult.limit
// directSend = sending directly back to the driver
- val (serializedResult, directSend) = {
- if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
+ val serializedResult = {
+ if (maxResultSize > 0 && resultSize > maxResultSize) {
+ logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
+ s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
+ s"dropping it.")
+ ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
+ } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
val blockId = TaskResultBlockId(taskId)
env.blockManager.putBytes(
blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
- (ser.serialize(new IndirectTaskResult[Any](blockId)), false)
+ logInfo(
+ s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
+ ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
} else {
- (serializedDirectResult, true)
+ logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
+ serializedDirectResult
}
}
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
- if (directSend) {
- logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
- } else {
- logInfo(
- s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
- }
} catch {
case ffe: FetchFailedException => {
val reason = ffe.toTaskEndReason
@@ -249,13 +264,13 @@ private[spark] class Executor(
m.executorRunTime = serviceTime
m.jvmGCTime = gcTime - startGCTime
}
- val reason = ExceptionFailure(t.getClass.getName, t.getMessage, t.getStackTrace, metrics)
+ val reason = new ExceptionFailure(t, metrics)
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
// Don't forcibly exit unless the exception was inherently fatal, to avoid
// stopping other tasks unnecessarily.
if (Utils.isFatalError(t)) {
- ExecutorUncaughtExceptionHandler.uncaughtException(t)
+ SparkUncaughtExceptionHandler.uncaughtException(t)
}
}
} finally {
@@ -319,19 +334,21 @@ private[spark] class Executor(
* SparkContext. Also adds any new JARs we fetched to the class loader.
*/
private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
- val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
+ lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
synchronized {
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager,
- hadoopConf)
+ // Fetch file with useCache mode, close cache for local mode.
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf,
+ env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
currentFiles(name) = timestamp
}
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager,
- hadoopConf)
+ // Fetch file with useCache mode, close cache for local mode.
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf,
+ env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
currentJars(name) = timestamp
// Add it to our class loader
val localName = name.split("/").last
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala
new file mode 100644
index 0000000000000..41925f7e97e84
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import akka.actor.Actor
+import org.apache.spark.Logging
+
+import org.apache.spark.util.{Utils, ActorLogReceive}
+
+/**
+ * Driver -> Executor message to trigger a thread dump.
+ */
+private[spark] case object TriggerThreadDump
+
+/**
+ * Actor that runs inside of executors to enable driver -> executor RPC.
+ */
+private[spark]
+class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging {
+
+ override def receiveWithLogging = {
+ case TriggerThreadDump =>
+ sender ! Utils.getThreadDump()
+ }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala
index 38be2c58b333f..52862ae0ca5e4 100644
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala
@@ -17,6 +17,8 @@
package org.apache.spark.executor
+import org.apache.spark.util.SparkExitCode._
+
/**
* These are exit codes that executors should use to provide the master with information about
* executor failures assuming that cluster management framework can capture the exit codes (but
@@ -27,16 +29,6 @@ package org.apache.spark.executor
*/
private[spark]
object ExecutorExitCode {
- /** The default uncaught exception handler was reached. */
- val UNCAUGHT_EXCEPTION = 50
-
- /** The default uncaught exception handler was called and an exception was encountered while
- logging the exception. */
- val UNCAUGHT_EXCEPTION_TWICE = 51
-
- /** The default uncaught exception handler was reached, and the uncaught exception was an
- OutOfMemoryError. */
- val OOM = 52
/** DiskStore failed to create a local temporary directory after many attempts. */
val DISK_STORE_FAILED_TO_CREATE_DIR = 53
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorUncaughtExceptionHandler.scala
deleted file mode 100644
index b0e984c03964c..0000000000000
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorUncaughtExceptionHandler.scala
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.executor
-
-import org.apache.spark.Logging
-import org.apache.spark.util.Utils
-
-/**
- * The default uncaught exception handler for Executors terminates the whole process, to avoid
- * getting into a bad state indefinitely. Since Executors are relatively lightweight, it's better
- * to fail fast when things go wrong.
- */
-private[spark] object ExecutorUncaughtExceptionHandler
- extends Thread.UncaughtExceptionHandler with Logging {
-
- override def uncaughtException(thread: Thread, exception: Throwable) {
- try {
- logError("Uncaught exception in thread " + thread, exception)
-
- // We may have been called from a shutdown hook. If so, we must not call System.exit().
- // (If we do, we will deadlock.)
- if (!Utils.inShutdown()) {
- if (exception.isInstanceOf[OutOfMemoryError]) {
- System.exit(ExecutorExitCode.OOM)
- } else {
- System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
- }
- }
- } catch {
- case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
- case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
- }
- }
-
- def uncaughtException(exception: Throwable) {
- uncaughtException(Thread.currentThread(), exception)
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
index bca0b152268ad..f15e6bc33fb41 100644
--- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -19,6 +19,8 @@ package org.apache.spark.executor
import java.nio.ByteBuffer
+import scala.collection.JavaConversions._
+
import org.apache.mesos.protobuf.ByteString
import org.apache.mesos.{Executor => MesosExecutor, ExecutorDriver, MesosExecutorDriver, MesosNativeLibrary}
import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _}
@@ -50,14 +52,23 @@ private[spark] class MesosExecutorBackend
executorInfo: ExecutorInfo,
frameworkInfo: FrameworkInfo,
slaveInfo: SlaveInfo) {
- logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue)
+
+ // Get num cores for this task from ExecutorInfo, created in MesosSchedulerBackend.
+ val cpusPerTask = executorInfo.getResourcesList
+ .find(_.getName == "cpus")
+ .map(_.getScalar.getValue.toInt)
+ .getOrElse(0)
+ val executorId = executorInfo.getExecutorId.getValue
+
+ logInfo(s"Registered with Mesos as executor ID $executorId with $cpusPerTask cpus")
this.driver = driver
val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) ++
Seq[(String, String)](("spark.app.id", frameworkInfo.getId.getValue))
executor = new Executor(
- executorInfo.getExecutorId.getValue,
+ executorId,
slaveInfo.getHostname,
- properties)
+ properties,
+ cpusPerTask)
}
override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) {
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 3e49b6235aff3..51b5328cb4c8f 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -82,6 +82,12 @@ class TaskMetrics extends Serializable {
*/
var inputMetrics: Option[InputMetrics] = None
+ /**
+ * If this task writes data externally (e.g. to a distributed filesystem), metrics on how much
+ * data was written are stored here.
+ */
+ var outputMetrics: Option[OutputMetrics] = None
+
/**
* If this task reads from shuffle output, metrics on getting shuffle data will be collected here.
* This includes read metrics aggregated over all the task's shuffle dependencies.
@@ -157,6 +163,16 @@ object DataReadMethod extends Enumeration with Serializable {
val Memory, Disk, Hadoop, Network = Value
}
+/**
+ * :: DeveloperApi ::
+ * Method by which output data was written.
+ */
+@DeveloperApi
+object DataWriteMethod extends Enumeration with Serializable {
+ type DataWriteMethod = Value
+ val Hadoop = Value
+}
+
/**
* :: DeveloperApi ::
* Metrics about reading input data.
@@ -169,6 +185,17 @@ case class InputMetrics(readMethod: DataReadMethod.Value) {
var bytesRead: Long = 0L
}
+/**
+ * :: DeveloperApi ::
+ * Metrics about writing output data.
+ */
+@DeveloperApi
+case class OutputMetrics(writeMethod: DataWriteMethod.Value) {
+ /**
+ * Total bytes written
+ */
+ var bytesWritten: Long = 0L
+}
/**
* :: DeveloperApi ::
diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
new file mode 100644
index 0000000000000..89b29af2000c8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.input
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.{BytesWritable, LongWritable}
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
+import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext}
+
+/**
+ * Custom Input Format for reading and splitting flat binary files that contain records,
+ * each of which are a fixed size in bytes. The fixed record size is specified through
+ * a parameter recordLength in the Hadoop configuration.
+ */
+private[spark] object FixedLengthBinaryInputFormat {
+ /** Property name to set in Hadoop JobConfs for record length */
+ val RECORD_LENGTH_PROPERTY = "org.apache.spark.input.FixedLengthBinaryInputFormat.recordLength"
+
+ /** Retrieves the record length property from a Hadoop configuration */
+ def getRecordLength(context: JobContext): Int = {
+ context.getConfiguration.get(RECORD_LENGTH_PROPERTY).toInt
+ }
+}
+
+private[spark] class FixedLengthBinaryInputFormat
+ extends FileInputFormat[LongWritable, BytesWritable] {
+
+ private var recordLength = -1
+
+ /**
+ * Override of isSplitable to ensure initial computation of the record length
+ */
+ override def isSplitable(context: JobContext, filename: Path): Boolean = {
+ if (recordLength == -1) {
+ recordLength = FixedLengthBinaryInputFormat.getRecordLength(context)
+ }
+ if (recordLength <= 0) {
+ println("record length is less than 0, file cannot be split")
+ false
+ } else {
+ true
+ }
+ }
+
+ /**
+ * This input format overrides computeSplitSize() to make sure that each split
+ * only contains full records. Each InputSplit passed to FixedLengthBinaryRecordReader
+ * will start at the first byte of a record, and the last byte will the last byte of a record.
+ */
+ override def computeSplitSize(blockSize: Long, minSize: Long, maxSize: Long): Long = {
+ val defaultSize = super.computeSplitSize(blockSize, minSize, maxSize)
+ // If the default size is less than the length of a record, make it equal to it
+ // Otherwise, make sure the split size is as close to possible as the default size,
+ // but still contains a complete set of records, with the first record
+ // starting at the first byte in the split and the last record ending with the last byte
+ if (defaultSize < recordLength) {
+ recordLength.toLong
+ } else {
+ (Math.floor(defaultSize / recordLength) * recordLength).toLong
+ }
+ }
+
+ /**
+ * Create a FixedLengthBinaryRecordReader
+ */
+ override def createRecordReader(split: InputSplit, context: TaskAttemptContext)
+ : RecordReader[LongWritable, BytesWritable] = {
+ new FixedLengthBinaryRecordReader
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
new file mode 100644
index 0000000000000..36a1e5d475f46
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.input
+
+import java.io.IOException
+
+import org.apache.hadoop.fs.FSDataInputStream
+import org.apache.hadoop.io.compress.CompressionCodecFactory
+import org.apache.hadoop.io.{BytesWritable, LongWritable}
+import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.input.FileSplit
+
+/**
+ * FixedLengthBinaryRecordReader is returned by FixedLengthBinaryInputFormat.
+ * It uses the record length set in FixedLengthBinaryInputFormat to
+ * read one record at a time from the given InputSplit.
+ *
+ * Each call to nextKeyValue() updates the LongWritable key and BytesWritable value.
+ *
+ * key = record index (Long)
+ * value = the record itself (BytesWritable)
+ */
+private[spark] class FixedLengthBinaryRecordReader
+ extends RecordReader[LongWritable, BytesWritable] {
+
+ private var splitStart: Long = 0L
+ private var splitEnd: Long = 0L
+ private var currentPosition: Long = 0L
+ private var recordLength: Int = 0
+ private var fileInputStream: FSDataInputStream = null
+ private var recordKey: LongWritable = null
+ private var recordValue: BytesWritable = null
+
+ override def close() {
+ if (fileInputStream != null) {
+ fileInputStream.close()
+ }
+ }
+
+ override def getCurrentKey: LongWritable = {
+ recordKey
+ }
+
+ override def getCurrentValue: BytesWritable = {
+ recordValue
+ }
+
+ override def getProgress: Float = {
+ splitStart match {
+ case x if x == splitEnd => 0.0.toFloat
+ case _ => Math.min(
+ ((currentPosition - splitStart) / (splitEnd - splitStart)).toFloat, 1.0
+ ).toFloat
+ }
+ }
+
+ override def initialize(inputSplit: InputSplit, context: TaskAttemptContext) {
+ // the file input
+ val fileSplit = inputSplit.asInstanceOf[FileSplit]
+
+ // the byte position this fileSplit starts at
+ splitStart = fileSplit.getStart
+
+ // splitEnd byte marker that the fileSplit ends at
+ splitEnd = splitStart + fileSplit.getLength
+
+ // the actual file we will be reading from
+ val file = fileSplit.getPath
+ // job configuration
+ val job = context.getConfiguration
+ // check compression
+ val codec = new CompressionCodecFactory(job).getCodec(file)
+ if (codec != null) {
+ throw new IOException("FixedLengthRecordReader does not support reading compressed files")
+ }
+ // get the record length
+ recordLength = FixedLengthBinaryInputFormat.getRecordLength(context)
+ // get the filesystem
+ val fs = file.getFileSystem(job)
+ // open the File
+ fileInputStream = fs.open(file)
+ // seek to the splitStart position
+ fileInputStream.seek(splitStart)
+ // set our current position
+ currentPosition = splitStart
+ }
+
+ override def nextKeyValue(): Boolean = {
+ if (recordKey == null) {
+ recordKey = new LongWritable()
+ }
+ // the key is a linear index of the record, given by the
+ // position the record starts divided by the record length
+ recordKey.set(currentPosition / recordLength)
+ // the recordValue to place the bytes into
+ if (recordValue == null) {
+ recordValue = new BytesWritable(new Array[Byte](recordLength))
+ }
+ // read a record if the currentPosition is less than the split end
+ if (currentPosition < splitEnd) {
+ // setup a buffer to store the record
+ val buffer = recordValue.getBytes
+ fileInputStream.readFully(buffer)
+ // update our current position
+ currentPosition = currentPosition + recordLength
+ // return true
+ return true
+ }
+ false
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
new file mode 100644
index 0000000000000..457472547fcbb
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.input
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+
+import scala.collection.JavaConversions._
+
+import com.google.common.io.ByteStreams
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit}
+
+import org.apache.spark.annotation.Experimental
+
+/**
+ * A general format for reading whole files in as streams, byte arrays,
+ * or other functions to be added
+ */
+private[spark] abstract class StreamFileInputFormat[T]
+ extends CombineFileInputFormat[String, T]
+{
+ override protected def isSplitable(context: JobContext, file: Path): Boolean = false
+
+ /**
+ * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API
+ * which is set through setMaxSplitSize
+ */
+ def setMinPartitions(context: JobContext, minPartitions: Int) {
+ val files = listStatus(context)
+ val totalLen = files.map { file =>
+ if (file.isDir) 0L else file.getLen
+ }.sum
+
+ val maxSplitSize = Math.ceil(totalLen * 1.0 / files.length).toLong
+ super.setMaxSplitSize(maxSplitSize)
+ }
+
+ def createRecordReader(split: InputSplit, taContext: TaskAttemptContext): RecordReader[String, T]
+
+}
+
+/**
+ * An abstract class of [[org.apache.hadoop.mapreduce.RecordReader RecordReader]]
+ * to reading files out as streams
+ */
+private[spark] abstract class StreamBasedRecordReader[T](
+ split: CombineFileSplit,
+ context: TaskAttemptContext,
+ index: Integer)
+ extends RecordReader[String, T] {
+
+ // True means the current file has been processed, then skip it.
+ private var processed = false
+
+ private var key = ""
+ private var value: T = null.asInstanceOf[T]
+
+ override def initialize(split: InputSplit, context: TaskAttemptContext) = {}
+ override def close() = {}
+
+ override def getProgress = if (processed) 1.0f else 0.0f
+
+ override def getCurrentKey = key
+
+ override def getCurrentValue = value
+
+ override def nextKeyValue = {
+ if (!processed) {
+ val fileIn = new PortableDataStream(split, context, index)
+ value = parseStream(fileIn)
+ fileIn.close() // if it has not been open yet, close does nothing
+ key = fileIn.getPath
+ processed = true
+ true
+ } else {
+ false
+ }
+ }
+
+ /**
+ * Parse the stream (and close it afterwards) and return the value as in type T
+ * @param inStream the stream to be read in
+ * @return the data formatted as
+ */
+ def parseStream(inStream: PortableDataStream): T
+}
+
+/**
+ * Reads the record in directly as a stream for other objects to manipulate and handle
+ */
+private[spark] class StreamRecordReader(
+ split: CombineFileSplit,
+ context: TaskAttemptContext,
+ index: Integer)
+ extends StreamBasedRecordReader[PortableDataStream](split, context, index) {
+
+ def parseStream(inStream: PortableDataStream): PortableDataStream = inStream
+}
+
+/**
+ * The format for the PortableDataStream files
+ */
+private[spark] class StreamInputFormat extends StreamFileInputFormat[PortableDataStream] {
+ override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext) = {
+ new CombineFileRecordReader[String, PortableDataStream](
+ split.asInstanceOf[CombineFileSplit], taContext, classOf[StreamRecordReader])
+ }
+}
+
+/**
+ * A class that allows DataStreams to be serialized and moved around by not creating them
+ * until they need to be read
+ * @note TaskAttemptContext is not serializable resulting in the confBytes construct
+ * @note CombineFileSplit is not serializable resulting in the splitBytes construct
+ */
+@Experimental
+class PortableDataStream(
+ @transient isplit: CombineFileSplit,
+ @transient context: TaskAttemptContext,
+ index: Integer)
+ extends Serializable {
+
+ // transient forces file to be reopened after being serialization
+ // it is also used for non-serializable classes
+
+ @transient private var fileIn: DataInputStream = null
+ @transient private var isOpen = false
+
+ private val confBytes = {
+ val baos = new ByteArrayOutputStream()
+ context.getConfiguration.write(new DataOutputStream(baos))
+ baos.toByteArray
+ }
+
+ private val splitBytes = {
+ val baos = new ByteArrayOutputStream()
+ isplit.write(new DataOutputStream(baos))
+ baos.toByteArray
+ }
+
+ @transient private lazy val split = {
+ val bais = new ByteArrayInputStream(splitBytes)
+ val nsplit = new CombineFileSplit()
+ nsplit.readFields(new DataInputStream(bais))
+ nsplit
+ }
+
+ @transient private lazy val conf = {
+ val bais = new ByteArrayInputStream(confBytes)
+ val nconf = new Configuration()
+ nconf.readFields(new DataInputStream(bais))
+ nconf
+ }
+ /**
+ * Calculate the path name independently of opening the file
+ */
+ @transient private lazy val path = {
+ val pathp = split.getPath(index)
+ pathp.toString
+ }
+
+ /**
+ * Create a new DataInputStream from the split and context
+ */
+ def open(): DataInputStream = {
+ if (!isOpen) {
+ val pathp = split.getPath(index)
+ val fs = pathp.getFileSystem(conf)
+ fileIn = fs.open(pathp)
+ isOpen = true
+ }
+ fileIn
+ }
+
+ /**
+ * Read the file as a byte array
+ */
+ def toArray(): Array[Byte] = {
+ open()
+ val innerBuffer = ByteStreams.toByteArray(fileIn)
+ close()
+ innerBuffer
+ }
+
+ /**
+ * Close the file (if it is currently open)
+ */
+ def close() = {
+ if (isOpen) {
+ try {
+ fileIn.close()
+ isOpen = false
+ } catch {
+ case ioe: java.io.IOException => // do nothing
+ }
+ }
+ }
+
+ def getPath(): String = path
+}
+
diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
index 4cb450577796a..d3601cca832b2 100644
--- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
+++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
@@ -19,14 +19,13 @@ package org.apache.spark.input
import scala.collection.JavaConversions._
+import org.apache.hadoop.conf.{Configuration, Configurable}
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.InputSplit
import org.apache.hadoop.mapreduce.JobContext
import org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat
import org.apache.hadoop.mapreduce.RecordReader
import org.apache.hadoop.mapreduce.TaskAttemptContext
-import org.apache.hadoop.mapreduce.lib.input.CombineFileRecordReader
-import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit
/**
* A [[org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat CombineFileInputFormat]] for
@@ -34,23 +33,31 @@ import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit
* the value is the entire content of file.
*/
-private[spark] class WholeTextFileInputFormat extends CombineFileInputFormat[String, String] {
+private[spark] class WholeTextFileInputFormat
+ extends CombineFileInputFormat[String, String] with Configurable {
+
override protected def isSplitable(context: JobContext, file: Path): Boolean = false
+ private var conf: Configuration = _
+ def setConf(c: Configuration) {
+ conf = c
+ }
+ def getConf: Configuration = conf
+
override def createRecordReader(
split: InputSplit,
context: TaskAttemptContext): RecordReader[String, String] = {
- new CombineFileRecordReader[String, String](
- split.asInstanceOf[CombineFileSplit],
- context,
- classOf[WholeTextFileRecordReader])
+ val reader = new WholeCombineFileRecordReader(split, context)
+ reader.setConf(conf)
+ reader
}
/**
- * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API.
+ * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API,
+ * which is set through setMaxSplitSize
*/
- def setMaxSplitSize(context: JobContext, minPartitions: Int) {
+ def setMinPartitions(context: JobContext, minPartitions: Int) {
val files = listStatus(context)
val totalLen = files.map { file =>
if (file.isDir) 0L else file.getLen
diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
index 3564ab2e2a162..6d59b24eb0596 100644
--- a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
+++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
@@ -17,11 +17,13 @@
package org.apache.spark.input
+import org.apache.hadoop.conf.{Configuration, Configurable}
import com.google.common.io.{ByteStreams, Closeables}
import org.apache.hadoop.io.Text
+import org.apache.hadoop.io.compress.CompressionCodecFactory
import org.apache.hadoop.mapreduce.InputSplit
-import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit
+import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, CombineFileRecordReader}
import org.apache.hadoop.mapreduce.RecordReader
import org.apache.hadoop.mapreduce.TaskAttemptContext
@@ -34,7 +36,13 @@ private[spark] class WholeTextFileRecordReader(
split: CombineFileSplit,
context: TaskAttemptContext,
index: Integer)
- extends RecordReader[String, String] {
+ extends RecordReader[String, String] with Configurable {
+
+ private var conf: Configuration = _
+ def setConf(c: Configuration) {
+ conf = c
+ }
+ def getConf: Configuration = conf
private[this] val path = split.getPath(index)
private[this] val fs = path.getFileSystem(context.getConfiguration)
@@ -57,8 +65,16 @@ private[spark] class WholeTextFileRecordReader(
override def nextKeyValue(): Boolean = {
if (!processed) {
+ val conf = new Configuration
+ val factory = new CompressionCodecFactory(conf)
+ val codec = factory.getCodec(path) // infers from file ext.
val fileIn = fs.open(path)
- val innerBuffer = ByteStreams.toByteArray(fileIn)
+ val innerBuffer = if (codec != null) {
+ ByteStreams.toByteArray(codec.createInputStream(fileIn))
+ } else {
+ ByteStreams.toByteArray(fileIn)
+ }
+
value = new Text(innerBuffer).toString
Closeables.close(fileIn, false)
processed = true
@@ -68,3 +84,33 @@ private[spark] class WholeTextFileRecordReader(
}
}
}
+
+
+/**
+ * A [[org.apache.hadoop.mapreduce.RecordReader RecordReader]] for reading a single whole text file
+ * out in a key-value pair, where the key is the file path and the value is the entire content of
+ * the file.
+ */
+private[spark] class WholeCombineFileRecordReader(
+ split: InputSplit,
+ context: TaskAttemptContext)
+ extends CombineFileRecordReader[String, String](
+ split.asInstanceOf[CombineFileSplit],
+ context,
+ classOf[WholeTextFileRecordReader]
+ ) with Configurable {
+
+ private var conf: Configuration = _
+ def setConf(c: Configuration) {
+ conf = c
+ }
+ def getConf: Configuration = conf
+
+ override def initNextRecordReader(): Boolean = {
+ val r = super.initNextRecordReader()
+ if (r) {
+ this.curReader.asInstanceOf[WholeTextFileRecordReader].setConf(conf)
+ }
+ r
+ }
+}
diff --git a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
similarity index 79%
rename from core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
rename to core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
index 0c47afae54c8b..21b782edd2a9e 100644
--- a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
+++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
@@ -15,15 +15,24 @@
* limitations under the License.
*/
-package org.apache.hadoop.mapred
+package org.apache.spark.mapred
-private[apache]
+import java.lang.reflect.Modifier
+
+import org.apache.hadoop.mapred.{TaskAttemptID, JobID, JobConf, JobContext, TaskAttemptContext}
+
+private[spark]
trait SparkHadoopMapRedUtil {
def newJobContext(conf: JobConf, jobId: JobID): JobContext = {
val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl",
"org.apache.hadoop.mapred.JobContext")
val ctor = klass.getDeclaredConstructor(classOf[JobConf],
classOf[org.apache.hadoop.mapreduce.JobID])
+ // In Hadoop 1.0.x, JobContext is an interface, and JobContextImpl is package private.
+ // Make it accessible if it's not in order to access it.
+ if (!Modifier.isPublic(ctor.getModifiers)) {
+ ctor.setAccessible(true)
+ }
ctor.newInstance(conf, jobId).asInstanceOf[JobContext]
}
@@ -31,6 +40,10 @@ trait SparkHadoopMapRedUtil {
val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl",
"org.apache.hadoop.mapred.TaskAttemptContext")
val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID])
+ // See above
+ if (!Modifier.isPublic(ctor.getModifiers)) {
+ ctor.setAccessible(true)
+ }
ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
}
diff --git a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala
similarity index 96%
rename from core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
rename to core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala
index 1fca5729c6092..3340673f91156 100644
--- a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
+++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala
@@ -15,13 +15,14 @@
* limitations under the License.
*/
-package org.apache.hadoop.mapreduce
+package org.apache.spark.mapreduce
import java.lang.{Boolean => JBoolean, Integer => JInteger}
import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.mapreduce.{JobContext, JobID, TaskAttemptContext, TaskAttemptID}
-private[apache]
+private[spark]
trait SparkHadoopMapReduceUtil {
def newJobContext(conf: Configuration, jobId: JobID): JobContext = {
val klass = firstAvailableClass(
diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
index e0e91724271c8..1745d52c81923 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
@@ -17,20 +17,20 @@
package org.apache.spark.network
-import org.apache.spark.storage.StorageLevel
-
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.storage.{BlockId, StorageLevel}
+private[spark]
trait BlockDataManager {
/**
- * Interface to get local block data.
- *
- * @return Some(buffer) if the block exists locally, and None if it doesn't.
+ * Interface to get local block data. Throws an exception if the block cannot be found or
+ * cannot be read successfully.
*/
- def getBlockData(blockId: String): Option[ManagedBuffer]
+ def getBlockData(blockId: BlockId): ManagedBuffer
/**
* Put the block locally, using the given storage level.
*/
- def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit
+ def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala
deleted file mode 100644
index 34acaa563ca58..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.util.EventListener
-
-
-/**
- * Listener callback interface for [[BlockTransferService.fetchBlocks]].
- */
-trait BlockFetchingListener extends EventListener {
-
- /**
- * Called once per successfully fetched block.
- */
- def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit
-
- /**
- * Called upon failures. For each failure, this is called only once (i.e. not once per block).
- */
- def onBlockFetchFailure(exception: Throwable): Unit
-}
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
index 84d991fa6808c..dcbda5a8515dd 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
@@ -17,13 +17,19 @@
package org.apache.spark.network
-import scala.concurrent.{Await, Future}
-import scala.concurrent.duration.Duration
+import java.io.Closeable
+import java.nio.ByteBuffer
-import org.apache.spark.storage.StorageLevel
+import scala.concurrent.{Promise, Await, Future}
+import scala.concurrent.duration.Duration
+import org.apache.spark.Logging
+import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.shuffle.{ShuffleClient, BlockFetchingListener}
+import org.apache.spark.storage.{BlockManagerId, BlockId, StorageLevel}
-abstract class BlockTransferService {
+private[spark]
+abstract class BlockTransferService extends ShuffleClient with Closeable with Logging {
/**
* Initialize the transfer service by giving it the BlockDataManager that can be used to fetch
@@ -34,7 +40,7 @@ abstract class BlockTransferService {
/**
* Tear down the transfer service.
*/
- def stop(): Unit
+ def close(): Unit
/**
* Port number the service is listening on, available only after [[init]] is invoked.
@@ -50,17 +56,15 @@ abstract class BlockTransferService {
* Fetch a sequence of blocks from a remote node asynchronously,
* available only after [[init]] is invoked.
*
- * Note that [[BlockFetchingListener.onBlockFetchSuccess]] is called once per block,
- * while [[BlockFetchingListener.onBlockFetchFailure]] is called once per failure (not per block).
- *
* Note that this API takes a sequence so the implementation can batch requests, and does not
* return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as
* the data of a block is fetched, rather than waiting for all blocks to be fetched.
*/
- def fetchBlocks(
- hostName: String,
+ override def fetchBlocks(
+ host: String,
port: Int,
- blockIds: Seq[String],
+ execId: String,
+ blockIds: Array[String],
listener: BlockFetchingListener): Unit
/**
@@ -69,7 +73,8 @@ abstract class BlockTransferService {
def uploadBlock(
hostname: String,
port: Int,
- blockId: String,
+ execId: String,
+ blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel): Future[Unit]
@@ -78,40 +83,23 @@ abstract class BlockTransferService {
*
* It is also only available after [[init]] is invoked.
*/
- def fetchBlockSync(hostName: String, port: Int, blockId: String): ManagedBuffer = {
+ def fetchBlockSync(host: String, port: Int, execId: String, blockId: String): ManagedBuffer = {
// A monitor for the thread to wait on.
- val lock = new Object
- @volatile var result: Either[ManagedBuffer, Throwable] = null
- fetchBlocks(hostName, port, Seq(blockId), new BlockFetchingListener {
- override def onBlockFetchFailure(exception: Throwable): Unit = {
- lock.synchronized {
- result = Right(exception)
- lock.notify()
+ val result = Promise[ManagedBuffer]()
+ fetchBlocks(host, port, execId, Array(blockId),
+ new BlockFetchingListener {
+ override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
+ result.failure(exception)
}
- }
- override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
- lock.synchronized {
- result = Left(data)
- lock.notify()
- }
- }
- })
-
- // Sleep until result is no longer null
- lock.synchronized {
- while (result == null) {
- try {
- lock.wait()
- } catch {
- case e: InterruptedException =>
+ override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
+ val ret = ByteBuffer.allocate(data.size.toInt)
+ ret.put(data.nioByteBuffer())
+ ret.flip()
+ result.success(new NioManagedBuffer(ret))
}
- }
- }
+ })
- result match {
- case Left(data) => data
- case Right(e) => throw e
- }
+ Await.result(result.future, Duration.Inf)
}
/**
@@ -123,9 +111,10 @@ abstract class BlockTransferService {
def uploadBlockSync(
hostname: String,
port: Int,
- blockId: String,
+ execId: String,
+ blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel): Unit = {
- Await.result(uploadBlock(hostname, port, blockId, blockData, level), Duration.Inf)
+ Await.result(uploadBlock(hostname, port, execId, blockId, blockData, level), Duration.Inf)
}
}
diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
deleted file mode 100644
index a4409181ec907..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
+++ /dev/null
@@ -1,146 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.io._
-import java.nio.ByteBuffer
-import java.nio.channels.FileChannel
-import java.nio.channels.FileChannel.MapMode
-
-import scala.util.Try
-
-import com.google.common.io.ByteStreams
-import io.netty.buffer.{ByteBufInputStream, ByteBuf}
-
-import org.apache.spark.util.{ByteBufferInputStream, Utils}
-
-
-/**
- * This interface provides an immutable view for data in the form of bytes. The implementation
- * should specify how the data is provided:
- *
- * - FileSegmentManagedBuffer: data backed by part of a file
- * - NioByteBufferManagedBuffer: data backed by a NIO ByteBuffer
- * - NettyByteBufManagedBuffer: data backed by a Netty ByteBuf
- */
-sealed abstract class ManagedBuffer {
- // Note that all the methods are defined with parenthesis because their implementations can
- // have side effects (io operations).
-
- /** Number of bytes of the data. */
- def size: Long
-
- /**
- * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the
- * returned ByteBuffer should not affect the content of this buffer.
- */
- def nioByteBuffer(): ByteBuffer
-
- /**
- * Exposes this buffer's data as an InputStream. The underlying implementation does not
- * necessarily check for the length of bytes read, so the caller is responsible for making sure
- * it does not go over the limit.
- */
- def inputStream(): InputStream
-}
-
-
-/**
- * A [[ManagedBuffer]] backed by a segment in a file
- */
-final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long)
- extends ManagedBuffer {
-
- override def size: Long = length
-
- override def nioByteBuffer(): ByteBuffer = {
- var channel: FileChannel = null
- try {
- channel = new RandomAccessFile(file, "r").getChannel
- channel.map(MapMode.READ_ONLY, offset, length)
- } catch {
- case e: IOException =>
- Try(channel.size).toOption match {
- case Some(fileLen) =>
- throw new IOException(s"Error in reading $this (actual file length $fileLen)", e)
- case None =>
- throw new IOException(s"Error in opening $this", e)
- }
- } finally {
- if (channel != null) {
- Utils.tryLog(channel.close())
- }
- }
- }
-
- override def inputStream(): InputStream = {
- var is: FileInputStream = null
- try {
- is = new FileInputStream(file)
- is.skip(offset)
- ByteStreams.limit(is, length)
- } catch {
- case e: IOException =>
- if (is != null) {
- Utils.tryLog(is.close())
- }
- Try(file.length).toOption match {
- case Some(fileLen) =>
- throw new IOException(s"Error in reading $this (actual file length $fileLen)", e)
- case None =>
- throw new IOException(s"Error in opening $this", e)
- }
- case e: Throwable =>
- if (is != null) {
- Utils.tryLog(is.close())
- }
- throw e
- }
- }
-
- override def toString: String = s"${getClass.getName}($file, $offset, $length)"
-}
-
-
-/**
- * A [[ManagedBuffer]] backed by [[java.nio.ByteBuffer]].
- */
-final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer {
-
- override def size: Long = buf.remaining()
-
- override def nioByteBuffer() = buf.duplicate()
-
- override def inputStream() = new ByteBufferInputStream(buf)
-}
-
-
-/**
- * A [[ManagedBuffer]] backed by a Netty [[ByteBuf]].
- */
-final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer {
-
- override def size: Long = buf.readableBytes()
-
- override def nioByteBuffer() = buf.nioBuffer()
-
- override def inputStream() = new ByteBufInputStream(buf)
-
- // TODO(rxin): Promote this to top level ManagedBuffer interface and add documentation for it.
- def release(): Unit = buf.release()
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
new file mode 100644
index 0000000000000..b089da8596e2b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import java.nio.ByteBuffer
+
+import scala.collection.JavaConversions._
+
+import org.apache.spark.Logging
+import org.apache.spark.network.BlockDataManager
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
+import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}
+import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock}
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.storage.{BlockId, StorageLevel}
+
+/**
+ * Serves requests to open blocks by simply registering one chunk per block requested.
+ * Handles opening and uploading arbitrary BlockManager blocks.
+ *
+ * Opened blocks are registered with the "one-for-one" strategy, meaning each Transport-layer Chunk
+ * is equivalent to one Spark-level shuffle block.
+ */
+class NettyBlockRpcServer(
+ serializer: Serializer,
+ blockManager: BlockDataManager)
+ extends RpcHandler with Logging {
+
+ private val streamManager = new OneForOneStreamManager()
+
+ override def receive(
+ client: TransportClient,
+ messageBytes: Array[Byte],
+ responseContext: RpcResponseCallback): Unit = {
+ val message = BlockTransferMessage.Decoder.fromByteArray(messageBytes)
+ logTrace(s"Received request: $message")
+
+ message match {
+ case openBlocks: OpenBlocks =>
+ val blocks: Seq[ManagedBuffer] =
+ openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
+ val streamId = streamManager.registerStream(blocks.iterator)
+ logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
+ responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray)
+
+ case uploadBlock: UploadBlock =>
+ // StorageLevel is serialized as bytes using our JavaSerializer.
+ val level: StorageLevel =
+ serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata))
+ val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
+ blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level)
+ responseContext.onSuccess(new Array[Byte](0))
+ }
+ }
+
+ override def getStreamManager(): StreamManager = streamManager
+}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
new file mode 100644
index 0000000000000..0027cbb0ff1fb
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import scala.collection.JavaConversions._
+import scala.concurrent.{Future, Promise}
+
+import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.network._
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory}
+import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap}
+import org.apache.spark.network.server._
+import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher}
+import org.apache.spark.network.shuffle.protocol.UploadBlock
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.storage.{BlockId, StorageLevel}
+import org.apache.spark.util.Utils
+
+/**
+ * A BlockTransferService that uses Netty to fetch a set of blocks at at time.
+ */
+class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager, numCores: Int)
+ extends BlockTransferService {
+
+ // TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
+ private val serializer = new JavaSerializer(conf)
+ private val authEnabled = securityManager.isAuthenticationEnabled()
+ private val transportConf = SparkTransportConf.fromSparkConf(conf, numCores)
+
+ private[this] var transportContext: TransportContext = _
+ private[this] var server: TransportServer = _
+ private[this] var clientFactory: TransportClientFactory = _
+ private[this] var appId: String = _
+
+ override def init(blockDataManager: BlockDataManager): Unit = {
+ val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = {
+ val nettyRpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
+ if (!authEnabled) {
+ (nettyRpcHandler, None)
+ } else {
+ (new SaslRpcHandler(nettyRpcHandler, securityManager),
+ Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager)))
+ }
+ }
+ transportContext = new TransportContext(transportConf, rpcHandler)
+ clientFactory = transportContext.createClientFactory(bootstrap.toList)
+ server = transportContext.createServer()
+ appId = conf.getAppId
+ logInfo("Server created on " + server.getPort)
+ }
+
+ override def fetchBlocks(
+ host: String,
+ port: Int,
+ execId: String,
+ blockIds: Array[String],
+ listener: BlockFetchingListener): Unit = {
+ logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
+ try {
+ val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
+ override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
+ val client = clientFactory.createClient(host, port)
+ new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start()
+ }
+ }
+
+ val maxRetries = transportConf.maxIORetries()
+ if (maxRetries > 0) {
+ // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
+ // a bug in this code. We should remove the if statement once we're sure of the stability.
+ new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()
+ } else {
+ blockFetchStarter.createAndStart(blockIds, listener)
+ }
+ } catch {
+ case e: Exception =>
+ logError("Exception while beginning fetchBlocks", e)
+ blockIds.foreach(listener.onBlockFetchFailure(_, e))
+ }
+ }
+
+ override def hostName: String = Utils.localHostName()
+
+ override def port: Int = server.getPort
+
+ override def uploadBlock(
+ hostname: String,
+ port: Int,
+ execId: String,
+ blockId: BlockId,
+ blockData: ManagedBuffer,
+ level: StorageLevel): Future[Unit] = {
+ val result = Promise[Unit]()
+ val client = clientFactory.createClient(hostname, port)
+
+ // StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded
+ // using our binary protocol.
+ val levelBytes = serializer.newInstance().serialize(level).array()
+
+ // Convert or copy nio buffer into array in order to serialize it.
+ val nioBuffer = blockData.nioByteBuffer()
+ val array = if (nioBuffer.hasArray) {
+ nioBuffer.array()
+ } else {
+ val data = new Array[Byte](nioBuffer.remaining())
+ nioBuffer.get(data)
+ data
+ }
+
+ client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteArray,
+ new RpcResponseCallback {
+ override def onSuccess(response: Array[Byte]): Unit = {
+ logTrace(s"Successfully uploaded block $blockId")
+ result.success()
+ }
+ override def onFailure(e: Throwable): Unit = {
+ logError(s"Error while uploading block $blockId", e)
+ result.failure(e)
+ }
+ })
+
+ result.future
+ }
+
+ override def close(): Unit = {
+ server.close()
+ clientFactory.close()
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala
deleted file mode 100644
index b5870152c5a64..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import org.apache.spark.SparkConf
-
-/**
- * A central location that tracks all the settings we exposed to users.
- */
-private[spark]
-class NettyConfig(conf: SparkConf) {
-
- /** Port the server listens on. Default to a random port. */
- private[netty] val serverPort = conf.getInt("spark.shuffle.io.port", 0)
-
- /** IO mode: nio, oio, epoll, or auto (try epoll first and then nio). */
- private[netty] val ioMode = conf.get("spark.shuffle.io.mode", "nio").toLowerCase
-
- /** Connect timeout in secs. Default 60 secs. */
- private[netty] val connectTimeoutMs = conf.getInt("spark.shuffle.io.connectionTimeout", 60) * 1000
-
- /**
- * Percentage of the desired amount of time spent for I/O in the child event loops.
- * Only applicable in nio and epoll.
- */
- private[netty] val ioRatio = conf.getInt("spark.shuffle.io.netty.ioRatio", 80)
-
- /** Requested maximum length of the queue of incoming connections. */
- private[netty] val backLog: Option[Int] = conf.getOption("spark.shuffle.io.backLog").map(_.toInt)
-
- /**
- * Receive buffer size (SO_RCVBUF).
- * Note: the optimal size for receive buffer and send buffer should be
- * latency * network_bandwidth.
- * Assuming latency = 1ms, network_bandwidth = 10Gbps
- * buffer size should be ~ 1.25MB
- */
- private[netty] val receiveBuf: Option[Int] =
- conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt)
-
- /** Send buffer size (SO_SNDBUF). */
- private[netty] val sendBuf: Option[Int] =
- conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt)
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala b/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala
deleted file mode 100644
index 0d7695072a7b1..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala
+++ /dev/null
@@ -1,25 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import org.apache.spark.storage.{BlockId, FileSegment}
-
-trait PathResolver {
- /** Get the file segment in which the given block resides. */
- def getBlockLocation(blockId: BlockId): FileSegment
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala
new file mode 100644
index 0000000000000..cef203006d685
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import org.apache.spark.SparkConf
+import org.apache.spark.network.util.{TransportConf, ConfigProvider}
+
+/**
+ * Provides a utility for transforming from a SparkConf inside a Spark JVM (e.g., Executor,
+ * Driver, or a standalone shuffle service) into a TransportConf with details on our environment
+ * like the number of cores that are allocated to this JVM.
+ */
+object SparkTransportConf {
+ /**
+ * Specifies an upper bound on the number of Netty threads that Spark requires by default.
+ * In practice, only 2-4 cores should be required to transfer roughly 10 Gb/s, and each core
+ * that we use will have an initial overhead of roughly 32 MB of off-heap memory, which comes
+ * at a premium.
+ *
+ * Thus, this value should still retain maximum throughput and reduce wasted off-heap memory
+ * allocation. It can be overridden by setting the number of serverThreads and clientThreads
+ * manually in Spark's configuration.
+ */
+ private val MAX_DEFAULT_NETTY_THREADS = 8
+
+ /**
+ * Utility for creating a [[TransportConf]] from a [[SparkConf]].
+ * @param numUsableCores if nonzero, this will restrict the server and client threads to only
+ * use the given number of cores, rather than all of the machine's cores.
+ * This restriction will only occur if these properties are not already set.
+ */
+ def fromSparkConf(_conf: SparkConf, numUsableCores: Int = 0): TransportConf = {
+ val conf = _conf.clone
+
+ // Specify thread configuration based on our JVM's allocation of cores (rather than necessarily
+ // assuming we have all the machine's cores).
+ // NB: Only set if serverThreads/clientThreads not already set.
+ val numThreads = defaultNumThreads(numUsableCores)
+ conf.set("spark.shuffle.io.serverThreads",
+ conf.get("spark.shuffle.io.serverThreads", numThreads.toString))
+ conf.set("spark.shuffle.io.clientThreads",
+ conf.get("spark.shuffle.io.clientThreads", numThreads.toString))
+
+ new TransportConf(new ConfigProvider {
+ override def get(name: String): String = conf.get(name)
+ })
+ }
+
+ /**
+ * Returns the default number of threads for both the Netty client and server thread pools.
+ * If numUsableCores is 0, we will use Runtime get an approximate number of available cores.
+ */
+ private def defaultNumThreads(numUsableCores: Int): Int = {
+ val availableCores =
+ if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
+ math.min(availableCores, MAX_DEFAULT_NETTY_THREADS)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala
deleted file mode 100644
index e28219dd7745b..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import java.util.EventListener
-
-
-trait BlockClientListener extends EventListener {
-
- def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit
-
- def onFetchFailure(blockId: String, errorMsg: String): Unit
-
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala
deleted file mode 100644
index 5aea7ba2f3673..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala
+++ /dev/null
@@ -1,132 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import java.util.concurrent.TimeoutException
-
-import io.netty.bootstrap.Bootstrap
-import io.netty.buffer.PooledByteBufAllocator
-import io.netty.channel.socket.SocketChannel
-import io.netty.channel.{ChannelFutureListener, ChannelFuture, ChannelInitializer, ChannelOption}
-import io.netty.handler.codec.LengthFieldBasedFrameDecoder
-import io.netty.handler.codec.string.StringEncoder
-import io.netty.util.CharsetUtil
-
-import org.apache.spark.Logging
-
-/**
- * Client for fetching data blocks from [[org.apache.spark.network.netty.server.BlockServer]].
- * Use [[BlockFetchingClientFactory]] to instantiate this client.
- *
- * The constructor blocks until a connection is successfully established.
- *
- * See [[org.apache.spark.network.netty.server.BlockServer]] for client/server protocol.
- *
- * Concurrency: thread safe and can be called from multiple threads.
- */
-@throws[TimeoutException]
-private[spark]
-class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, port: Int)
- extends Logging {
-
- private val handler = new BlockFetchingClientHandler
-
- /** Netty Bootstrap for creating the TCP connection. */
- private val bootstrap: Bootstrap = {
- val b = new Bootstrap
- b.group(factory.workerGroup)
- .channel(factory.socketChannelClass)
- // Use pooled buffers to reduce temporary buffer allocation
- .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
- // Disable Nagle's Algorithm since we don't want packets to wait
- .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE)
- .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE)
- .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, factory.conf.connectTimeoutMs)
-
- b.handler(new ChannelInitializer[SocketChannel] {
- override def initChannel(ch: SocketChannel): Unit = {
- ch.pipeline
- .addLast("encoder", new StringEncoder(CharsetUtil.UTF_8))
- // maxFrameLength = 2G, lengthFieldOffset = 0, lengthFieldLength = 4
- .addLast("framedLengthDecoder", new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 4))
- .addLast("handler", handler)
- }
- })
- b
- }
-
- /** Netty ChannelFuture for the connection. */
- private val cf: ChannelFuture = bootstrap.connect(hostname, port)
- if (!cf.awaitUninterruptibly(factory.conf.connectTimeoutMs)) {
- throw new TimeoutException(
- s"Connecting to $hostname:$port timed out (${factory.conf.connectTimeoutMs} ms)")
- }
-
- /**
- * Ask the remote server for a sequence of blocks, and execute the callback.
- *
- * Note that this is asynchronous and returns immediately. Upstream caller should throttle the
- * rate of fetching; otherwise we could run out of memory.
- *
- * @param blockIds sequence of block ids to fetch.
- * @param listener callback to fire on fetch success / failure.
- */
- def fetchBlocks(blockIds: Seq[String], listener: BlockClientListener): Unit = {
- // It's best to limit the number of "write" calls since it needs to traverse the whole pipeline.
- // It's also best to limit the number of "flush" calls since it requires system calls.
- // Let's concatenate the string and then call writeAndFlush once.
- // This is also why this implementation might be more efficient than multiple, separate
- // fetch block calls.
- var startTime: Long = 0
- logTrace {
- startTime = System.nanoTime
- s"Sending request $blockIds to $hostname:$port"
- }
-
- blockIds.foreach { blockId =>
- handler.addRequest(blockId, listener)
- }
-
- val writeFuture = cf.channel().writeAndFlush(blockIds.mkString("\n") + "\n")
- writeFuture.addListener(new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture): Unit = {
- if (future.isSuccess) {
- logTrace {
- val timeTaken = (System.nanoTime - startTime).toDouble / 1000000
- s"Sending request $blockIds to $hostname:$port took $timeTaken ms"
- }
- } else {
- // Fail all blocks.
- val errorMsg =
- s"Failed to send request $blockIds to $hostname:$port: ${future.cause.getMessage}"
- logError(errorMsg, future.cause)
- blockIds.foreach { blockId =>
- listener.onFetchFailure(blockId, errorMsg)
- handler.removeRequest(blockId)
- }
- }
- }
- })
- }
-
- def waitForClose(): Unit = {
- cf.channel().closeFuture().sync()
- }
-
- def close(): Unit = cf.channel().close()
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala
deleted file mode 100644
index 2b28402c52b49..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala
+++ /dev/null
@@ -1,99 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import io.netty.channel.epoll.{EpollEventLoopGroup, EpollSocketChannel}
-import io.netty.channel.nio.NioEventLoopGroup
-import io.netty.channel.oio.OioEventLoopGroup
-import io.netty.channel.socket.nio.NioSocketChannel
-import io.netty.channel.socket.oio.OioSocketChannel
-import io.netty.channel.{EventLoopGroup, Channel}
-
-import org.apache.spark.SparkConf
-import org.apache.spark.network.netty.NettyConfig
-import org.apache.spark.util.Utils
-
-/**
- * Factory for creating [[BlockFetchingClient]] by using createClient. This factory reuses
- * the worker thread pool for Netty.
- *
- * Concurrency: createClient is safe to be called from multiple threads concurrently.
- */
-private[spark]
-class BlockFetchingClientFactory(val conf: NettyConfig) {
-
- def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf))
-
- /** A thread factory so the threads are named (for debugging). */
- val threadFactory = Utils.namedThreadFactory("spark-shuffle-client")
-
- /** The following two are instantiated by the [[init]] method, depending ioMode. */
- var socketChannelClass: Class[_ <: Channel] = _
- var workerGroup: EventLoopGroup = _
-
- init()
-
- /** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */
- private def init(): Unit = {
- def initOio(): Unit = {
- socketChannelClass = classOf[OioSocketChannel]
- workerGroup = new OioEventLoopGroup(0, threadFactory)
- }
- def initNio(): Unit = {
- socketChannelClass = classOf[NioSocketChannel]
- workerGroup = new NioEventLoopGroup(0, threadFactory)
- }
- def initEpoll(): Unit = {
- socketChannelClass = classOf[EpollSocketChannel]
- workerGroup = new EpollEventLoopGroup(0, threadFactory)
- }
-
- conf.ioMode match {
- case "nio" => initNio()
- case "oio" => initOio()
- case "epoll" => initEpoll()
- case "auto" =>
- // For auto mode, first try epoll (only available on Linux), then nio.
- try {
- initEpoll()
- } catch {
- // TODO: Should we log the throwable? But that always happen on non-Linux systems.
- // Perhaps the right thing to do is to check whether the system is Linux, and then only
- // call initEpoll on Linux.
- case e: Throwable => initNio()
- }
- }
- }
-
- /**
- * Create a new BlockFetchingClient connecting to the given remote host / port.
- *
- * This blocks until a connection is successfully established.
- *
- * Concurrency: This method is safe to call from multiple threads.
- */
- def createClient(remoteHost: String, remotePort: Int): BlockFetchingClient = {
- new BlockFetchingClient(this, remoteHost, remotePort)
- }
-
- def stop(): Unit = {
- if (workerGroup != null) {
- workerGroup.shutdownGracefully()
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala
deleted file mode 100644
index 83265b164299d..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala
+++ /dev/null
@@ -1,103 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import io.netty.buffer.ByteBuf
-import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
-
-import org.apache.spark.Logging
-
-
-/**
- * Handler that processes server responses. It uses the protocol documented in
- * [[org.apache.spark.network.netty.server.BlockServer]].
- *
- * Concurrency: thread safe and can be called from multiple threads.
- */
-private[client]
-class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] with Logging {
-
- /** Tracks the list of outstanding requests and their listeners on success/failure. */
- private val outstandingRequests = java.util.Collections.synchronizedMap {
- new java.util.HashMap[String, BlockClientListener]
- }
-
- def addRequest(blockId: String, listener: BlockClientListener): Unit = {
- outstandingRequests.put(blockId, listener)
- }
-
- def removeRequest(blockId: String): Unit = {
- outstandingRequests.remove(blockId)
- }
-
- override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
- val errorMsg = s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}"
- logError(errorMsg, cause)
-
- // Fire the failure callback for all outstanding blocks
- outstandingRequests.synchronized {
- val iter = outstandingRequests.entrySet().iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- entry.getValue.onFetchFailure(entry.getKey, errorMsg)
- }
- outstandingRequests.clear()
- }
-
- ctx.close()
- }
-
- override def channelRead0(ctx: ChannelHandlerContext, in: ByteBuf) {
- val totalLen = in.readInt()
- val blockIdLen = in.readInt()
- val blockIdBytes = new Array[Byte](math.abs(blockIdLen))
- in.readBytes(blockIdBytes)
- val blockId = new String(blockIdBytes)
- val blockSize = totalLen - math.abs(blockIdLen) - 4
-
- def server = ctx.channel.remoteAddress.toString
-
- // blockIdLen is negative when it is an error message.
- if (blockIdLen < 0) {
- val errorMessageBytes = new Array[Byte](blockSize)
- in.readBytes(errorMessageBytes)
- val errorMsg = new String(errorMessageBytes)
- logTrace(s"Received block $blockId ($blockSize B) with error $errorMsg from $server")
-
- val listener = outstandingRequests.get(blockId)
- if (listener == null) {
- // Ignore callback
- logWarning(s"Got a response for block $blockId but it is not in our outstanding requests")
- } else {
- outstandingRequests.remove(blockId)
- listener.onFetchFailure(blockId, errorMsg)
- }
- } else {
- logTrace(s"Received block $blockId ($blockSize B) from $server")
-
- val listener = outstandingRequests.get(blockId)
- if (listener == null) {
- // Ignore callback
- logWarning(s"Got a response for block $blockId but it is not in our outstanding requests")
- } else {
- outstandingRequests.remove(blockId)
- listener.onFetchSuccess(blockId, new ReferenceCountedBuffer(in))
- }
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala b/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala
deleted file mode 100644
index 9740ee64d1f2d..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-/**
- * A simple iterator that lazily initializes the underlying iterator.
- *
- * The use case is that sometimes we might have many iterators open at the same time, and each of
- * the iterator might initialize its own buffer (e.g. decompression buffer, deserialization buffer).
- * This could lead to too many buffers open. If this iterator is used, we lazily initialize those
- * buffers.
- */
-private[spark]
-class LazyInitIterator(createIterator: => Iterator[Any]) extends Iterator[Any] {
-
- lazy val proxy = createIterator
-
- override def hasNext: Boolean = {
- val gotNext = proxy.hasNext
- if (!gotNext) {
- close()
- }
- gotNext
- }
-
- override def next(): Any = proxy.next()
-
- def close(): Unit = Unit
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala b/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala
deleted file mode 100644
index ea1abf5eccc26..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import java.io.InputStream
-import java.nio.ByteBuffer
-
-import io.netty.buffer.{ByteBuf, ByteBufInputStream}
-
-
-/**
- * A buffer abstraction based on Netty's ByteBuf so we don't expose Netty.
- * This is a Scala value class.
- *
- * The buffer's life cycle is NOT managed by the JVM, and thus requiring explicit declaration of
- * reference by the retain method and release method.
- */
-private[spark]
-class ReferenceCountedBuffer(val underlying: ByteBuf) extends AnyVal {
-
- /** Return the nio ByteBuffer view of the underlying buffer. */
- def byteBuffer(): ByteBuffer = underlying.nioBuffer
-
- /** Creates a new input stream that starts from the current position of the buffer. */
- def inputStream(): InputStream = new ByteBufInputStream(underlying)
-
- /** Increment the reference counter by one. */
- def retain(): Unit = underlying.retain()
-
- /** Decrement the reference counter by one and release the buffer if the ref count is 0. */
- def release(): Unit = underlying.release()
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala
deleted file mode 100644
index 162e9cc6828d4..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-/**
- * Header describing a block. This is used only in the server pipeline.
- *
- * [[BlockServerHandler]] creates this, and [[BlockHeaderEncoder]] encodes it.
- *
- * @param blockSize length of the block content, excluding the length itself.
- * If positive, this is the header for a block (not part of the header).
- * If negative, this is the header and content for an error message.
- * @param blockId block id
- * @param error some error message from reading the block
- */
-private[server]
-class BlockHeader(val blockSize: Int, val blockId: String, val error: Option[String] = None)
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala
deleted file mode 100644
index 8e4dda4ef8595..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import io.netty.buffer.ByteBuf
-import io.netty.channel.ChannelHandlerContext
-import io.netty.handler.codec.MessageToByteEncoder
-
-/**
- * A simple encoder for BlockHeader. See [[BlockServer]] for the server to client protocol.
- */
-private[server]
-class BlockHeaderEncoder extends MessageToByteEncoder[BlockHeader] {
- override def encode(ctx: ChannelHandlerContext, msg: BlockHeader, out: ByteBuf): Unit = {
- // message = message length (4 bytes) + block id length (4 bytes) + block id + block data
- // message length = block id length (4 bytes) + size of block id + size of block data
- val blockIdBytes = msg.blockId.getBytes
- msg.error match {
- case Some(errorMsg) =>
- val errorBytes = errorMsg.getBytes
- out.writeInt(4 + blockIdBytes.length + errorBytes.size)
- out.writeInt(-blockIdBytes.length) // use negative block id length to represent errors
- out.writeBytes(blockIdBytes) // next is blockId itself
- out.writeBytes(errorBytes) // error message
- case None =>
- out.writeInt(4 + blockIdBytes.length + msg.blockSize)
- out.writeInt(blockIdBytes.length) // First 4 bytes is blockId length
- out.writeBytes(blockIdBytes) // next is blockId itself
- // msg of size blockSize will be written by ServerHandler
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala
deleted file mode 100644
index 7b2f9a8d4dfd0..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala
+++ /dev/null
@@ -1,162 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import java.net.InetSocketAddress
-
-import io.netty.bootstrap.ServerBootstrap
-import io.netty.buffer.PooledByteBufAllocator
-import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption}
-import io.netty.channel.epoll.{EpollEventLoopGroup, EpollServerSocketChannel}
-import io.netty.channel.nio.NioEventLoopGroup
-import io.netty.channel.oio.OioEventLoopGroup
-import io.netty.channel.socket.SocketChannel
-import io.netty.channel.socket.nio.NioServerSocketChannel
-import io.netty.channel.socket.oio.OioServerSocketChannel
-import io.netty.handler.codec.LineBasedFrameDecoder
-import io.netty.handler.codec.string.StringDecoder
-import io.netty.util.CharsetUtil
-
-import org.apache.spark.{Logging, SparkConf}
-import org.apache.spark.network.netty.NettyConfig
-import org.apache.spark.storage.BlockDataProvider
-import org.apache.spark.util.Utils
-
-
-/**
- * Server for serving Spark data blocks.
- * This should be used together with [[org.apache.spark.network.netty.client.BlockFetchingClient]].
- *
- * Protocol for requesting blocks (client to server):
- * One block id per line, e.g. to request 3 blocks: "block1\nblock2\nblock3\n"
- *
- * Protocol for sending blocks (server to client):
- * frame-length (4 bytes), block-id-length (4 bytes), block-id, block-data.
- *
- * frame-length should not include the length of itself.
- * If block-id-length is negative, then this is an error message rather than block-data. The real
- * length is the absolute value of the frame-length.
- *
- */
-private[spark]
-class BlockServer(conf: NettyConfig, dataProvider: BlockDataProvider) extends Logging {
-
- def this(sparkConf: SparkConf, dataProvider: BlockDataProvider) = {
- this(new NettyConfig(sparkConf), dataProvider)
- }
-
- def port: Int = _port
-
- def hostName: String = _hostName
-
- private var _port: Int = conf.serverPort
- private var _hostName: String = ""
- private var bootstrap: ServerBootstrap = _
- private var channelFuture: ChannelFuture = _
-
- init()
-
- /** Initialize the server. */
- private def init(): Unit = {
- bootstrap = new ServerBootstrap
- val bossThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-boss")
- val workerThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-worker")
-
- // Use only one thread to accept connections, and 2 * num_cores for worker.
- def initNio(): Unit = {
- val bossGroup = new NioEventLoopGroup(1, bossThreadFactory)
- val workerGroup = new NioEventLoopGroup(0, workerThreadFactory)
- workerGroup.setIoRatio(conf.ioRatio)
- bootstrap.group(bossGroup, workerGroup).channel(classOf[NioServerSocketChannel])
- }
- def initOio(): Unit = {
- val bossGroup = new OioEventLoopGroup(1, bossThreadFactory)
- val workerGroup = new OioEventLoopGroup(0, workerThreadFactory)
- bootstrap.group(bossGroup, workerGroup).channel(classOf[OioServerSocketChannel])
- }
- def initEpoll(): Unit = {
- val bossGroup = new EpollEventLoopGroup(1, bossThreadFactory)
- val workerGroup = new EpollEventLoopGroup(0, workerThreadFactory)
- workerGroup.setIoRatio(conf.ioRatio)
- bootstrap.group(bossGroup, workerGroup).channel(classOf[EpollServerSocketChannel])
- }
-
- conf.ioMode match {
- case "nio" => initNio()
- case "oio" => initOio()
- case "epoll" => initEpoll()
- case "auto" =>
- // For auto mode, first try epoll (only available on Linux), then nio.
- try {
- initEpoll()
- } catch {
- // TODO: Should we log the throwable? But that always happen on non-Linux systems.
- // Perhaps the right thing to do is to check whether the system is Linux, and then only
- // call initEpoll on Linux.
- case e: Throwable => initNio()
- }
- }
-
- // Use pooled buffers to reduce temporary buffer allocation
- bootstrap.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
- bootstrap.childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
-
- // Various (advanced) user-configured settings.
- conf.backLog.foreach { backLog =>
- bootstrap.option[java.lang.Integer](ChannelOption.SO_BACKLOG, backLog)
- }
- conf.receiveBuf.foreach { receiveBuf =>
- bootstrap.option[java.lang.Integer](ChannelOption.SO_RCVBUF, receiveBuf)
- }
- conf.sendBuf.foreach { sendBuf =>
- bootstrap.option[java.lang.Integer](ChannelOption.SO_SNDBUF, sendBuf)
- }
-
- bootstrap.childHandler(new ChannelInitializer[SocketChannel] {
- override def initChannel(ch: SocketChannel): Unit = {
- ch.pipeline
- .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024
- .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8))
- .addLast("blockHeaderEncoder", new BlockHeaderEncoder)
- .addLast("handler", new BlockServerHandler(dataProvider))
- }
- })
-
- channelFuture = bootstrap.bind(new InetSocketAddress(_port))
- channelFuture.sync()
-
- val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress]
- _port = addr.getPort
- _hostName = addr.getHostName
- }
-
- /** Shutdown the server. */
- def stop(): Unit = {
- if (channelFuture != null) {
- channelFuture.channel().close().awaitUninterruptibly()
- channelFuture = null
- }
- if (bootstrap != null && bootstrap.group() != null) {
- bootstrap.group().shutdownGracefully()
- }
- if (bootstrap != null && bootstrap.childGroup() != null) {
- bootstrap.childGroup().shutdownGracefully()
- }
- bootstrap = null
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala
deleted file mode 100644
index cc70bd0c5c477..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import io.netty.channel.ChannelInitializer
-import io.netty.channel.socket.SocketChannel
-import io.netty.handler.codec.LineBasedFrameDecoder
-import io.netty.handler.codec.string.StringDecoder
-import io.netty.util.CharsetUtil
-import org.apache.spark.storage.BlockDataProvider
-
-
-/** Channel initializer that sets up the pipeline for the BlockServer. */
-private[netty]
-class BlockServerChannelInitializer(dataProvider: BlockDataProvider)
- extends ChannelInitializer[SocketChannel] {
-
- override def initChannel(ch: SocketChannel): Unit = {
- ch.pipeline
- .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024
- .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8))
- .addLast("blockHeaderEncoder", new BlockHeaderEncoder)
- .addLast("handler", new BlockServerHandler(dataProvider))
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala
deleted file mode 100644
index 40dd5e5d1a2ac..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala
+++ /dev/null
@@ -1,140 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import java.io.FileInputStream
-import java.nio.ByteBuffer
-import java.nio.channels.FileChannel
-
-import io.netty.buffer.Unpooled
-import io.netty.channel._
-
-import org.apache.spark.Logging
-import org.apache.spark.storage.{FileSegment, BlockDataProvider}
-
-
-/**
- * A handler that processes requests from clients and writes block data back.
- *
- * The messages should have been processed by a LineBasedFrameDecoder and a StringDecoder first
- * so channelRead0 is called once per line (i.e. per block id).
- */
-private[server]
-class BlockServerHandler(dataProvider: BlockDataProvider)
- extends SimpleChannelInboundHandler[String] with Logging {
-
- override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
- logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause)
- ctx.close()
- }
-
- override def channelRead0(ctx: ChannelHandlerContext, blockId: String): Unit = {
- def client = ctx.channel.remoteAddress.toString
-
- // A helper function to send error message back to the client.
- def respondWithError(error: String): Unit = {
- ctx.writeAndFlush(new BlockHeader(-1, blockId, Some(error))).addListener(
- new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture) {
- if (!future.isSuccess) {
- // TODO: Maybe log the success case as well.
- logError(s"Error sending error back to $client", future.cause)
- ctx.close()
- }
- }
- }
- )
- }
-
- def writeFileSegment(segment: FileSegment): Unit = {
- // Send error message back if the block is too large. Even though we are capable of sending
- // large (2G+) blocks, the receiving end cannot handle it so let's fail fast.
- // Once we fixed the receiving end to be able to process large blocks, this should be removed.
- // Also make sure we update BlockHeaderEncoder to support length > 2G.
-
- // See [[BlockHeaderEncoder]] for the way length is encoded.
- if (segment.length + blockId.length + 4 > Int.MaxValue) {
- respondWithError(s"Block $blockId size ($segment.length) greater than 2G")
- return
- }
-
- var fileChannel: FileChannel = null
- try {
- fileChannel = new FileInputStream(segment.file).getChannel
- } catch {
- case e: Exception =>
- logError(
- s"Error opening channel for $blockId in ${segment.file} for request from $client", e)
- respondWithError(e.getMessage)
- }
-
- // Found the block. Send it back.
- if (fileChannel != null) {
- // Write the header and block data. In the case of failures, the listener on the block data
- // write should close the connection.
- ctx.write(new BlockHeader(segment.length.toInt, blockId))
-
- val region = new DefaultFileRegion(fileChannel, segment.offset, segment.length)
- ctx.writeAndFlush(region).addListener(new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture) {
- if (future.isSuccess) {
- logTrace(s"Sent block $blockId (${segment.length} B) back to $client")
- } else {
- logError(s"Error sending block $blockId to $client; closing connection", future.cause)
- ctx.close()
- }
- }
- })
- }
- }
-
- def writeByteBuffer(buf: ByteBuffer): Unit = {
- ctx.write(new BlockHeader(buf.remaining, blockId))
- ctx.writeAndFlush(Unpooled.wrappedBuffer(buf)).addListener(new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture) {
- if (future.isSuccess) {
- logTrace(s"Sent block $blockId (${buf.remaining} B) back to $client")
- } else {
- logError(s"Error sending block $blockId to $client; closing connection", future.cause)
- ctx.close()
- }
- }
- })
- }
-
- logTrace(s"Received request from $client to fetch block $blockId")
-
- var blockData: Either[FileSegment, ByteBuffer] = null
-
- // First make sure we can find the block. If not, send error back to the user.
- try {
- blockData = dataProvider.getBlockData(blockId)
- } catch {
- case e: Exception =>
- logError(s"Error opening block $blockId for request from $client", e)
- respondWithError(e.getMessage)
- return
- }
-
- blockData match {
- case Left(segment) => writeFileSegment(segment)
- case Right(buf) => writeByteBuffer(buf)
- }
-
- } // end of channelRead0
-}
diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
index f368209980f93..c2d9578be7ebb 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
@@ -20,11 +20,15 @@ package org.apache.spark.network.nio
import java.net._
import java.nio._
import java.nio.channels._
+import java.util.concurrent.ConcurrentLinkedQueue
import java.util.LinkedList
-import org.apache.spark._
-
+import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.util.control.NonFatal
+
+import org.apache.spark._
+import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer}
private[nio]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
@@ -51,7 +55,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
@volatile private var closed = false
var onCloseCallback: Connection => Unit = null
- var onExceptionCallback: (Connection, Exception) => Unit = null
+ val onExceptionCallbacks = new ConcurrentLinkedQueue[(Connection, Throwable) => Unit]
var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
val remoteAddress = getRemoteAddress()
@@ -130,20 +134,24 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
onCloseCallback = callback
}
- def onException(callback: (Connection, Exception) => Unit) {
- onExceptionCallback = callback
+ def onException(callback: (Connection, Throwable) => Unit) {
+ onExceptionCallbacks.add(callback)
}
def onKeyInterestChange(callback: (Connection, Int) => Unit) {
onKeyInterestChangeCallback = callback
}
- def callOnExceptionCallback(e: Exception) {
- if (onExceptionCallback != null) {
- onExceptionCallback(this, e)
- } else {
- logError("Error in connection to " + getRemoteConnectionManagerId() +
- " and OnExceptionCallback not registered", e)
+ def callOnExceptionCallbacks(e: Throwable) {
+ onExceptionCallbacks foreach {
+ callback =>
+ try {
+ callback(this, e)
+ } catch {
+ case NonFatal(e) => {
+ logWarning("Ignored error in onExceptionCallback", e)
+ }
+ }
}
}
@@ -323,7 +331,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
} catch {
case e: Exception => {
logError("Error connecting to " + address, e)
- callOnExceptionCallback(e)
+ callOnExceptionCallbacks(e)
}
}
}
@@ -348,7 +356,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
} catch {
case e: Exception => {
logWarning("Error finishing connection to " + address, e)
- callOnExceptionCallback(e)
+ callOnExceptionCallbacks(e)
}
}
true
@@ -393,7 +401,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
} catch {
case e: Exception => {
logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e)
- callOnExceptionCallback(e)
+ callOnExceptionCallbacks(e)
close()
return false
}
@@ -420,7 +428,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
case e: Exception =>
logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(),
e)
- callOnExceptionCallback(e)
+ callOnExceptionCallbacks(e)
close()
}
@@ -577,7 +585,7 @@ private[spark] class ReceivingConnection(
} catch {
case e: Exception => {
logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e)
- callOnExceptionCallback(e)
+ callOnExceptionCallbacks(e)
close()
return false
}
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
index 01cd27a907eea..df4b085d2251e 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
@@ -18,22 +18,28 @@
package org.apache.spark.network.nio
import java.io.IOException
+import java.lang.ref.WeakReference
import java.net._
import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit}
-import java.util.{Timer, TimerTask}
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue}
import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.language.postfixOps
+import com.google.common.base.Charsets.UTF_8
+import io.netty.util.{Timeout, TimerTask, HashedWheelTimer}
+
import org.apache.spark._
+import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer}
import org.apache.spark.util.Utils
+import scala.util.Try
+import scala.util.control.NonFatal
private[nio] class ConnectionManager(
port: Int,
@@ -51,19 +57,29 @@ private[nio] class ConnectionManager(
class MessageStatus(
val message: Message,
val connectionManagerId: ConnectionManagerId,
- completionHandler: MessageStatus => Unit) {
+ completionHandler: Try[Message] => Unit) {
+
+ def success(ackMessage: Message) {
+ if (ackMessage == null) {
+ failure(new NullPointerException)
+ }
+ else {
+ completionHandler(scala.util.Success(ackMessage))
+ }
+ }
- /** This is non-None if message has been ack'd */
- var ackMessage: Option[Message] = None
+ def failWithoutAck() {
+ completionHandler(scala.util.Failure(new IOException("Failed without being ACK'd")))
+ }
- def markDone(ackMessage: Option[Message]) {
- this.ackMessage = ackMessage
- completionHandler(this)
+ def failure(e: Throwable) {
+ completionHandler(scala.util.Failure(e))
}
}
private val selector = SelectorProvider.provider.openSelector()
- private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true)
+ private val ackTimeoutMonitor =
+ new HashedWheelTimer(Utils.namedThreadFactory("AckTimeoutMonitor"))
private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60)
@@ -72,14 +88,32 @@ private[nio] class ConnectionManager(
conf.getInt("spark.core.connection.handler.threads.max", 60),
conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable](),
- Utils.namedThreadFactory("handle-message-executor"))
+ Utils.namedThreadFactory("handle-message-executor")) {
+
+ override def afterExecute(r: Runnable, t: Throwable): Unit = {
+ super.afterExecute(r, t)
+ if (t != null && NonFatal(t)) {
+ logError("Error in handleMessageExecutor is not handled properly", t)
+ }
+ }
+
+ }
private val handleReadWriteExecutor = new ThreadPoolExecutor(
conf.getInt("spark.core.connection.io.threads.min", 4),
conf.getInt("spark.core.connection.io.threads.max", 32),
conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable](),
- Utils.namedThreadFactory("handle-read-write-executor"))
+ Utils.namedThreadFactory("handle-read-write-executor")) {
+
+ override def afterExecute(r: Runnable, t: Throwable): Unit = {
+ super.afterExecute(r, t)
+ if (t != null && NonFatal(t)) {
+ logError("Error in handleReadWriteExecutor is not handled properly", t)
+ }
+ }
+
+ }
// Use a different, yet smaller, thread pool - infrequently used with very short lived tasks :
// which should be executed asap
@@ -88,7 +122,16 @@ private[nio] class ConnectionManager(
conf.getInt("spark.core.connection.connect.threads.max", 8),
conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable](),
- Utils.namedThreadFactory("handle-connect-executor"))
+ Utils.namedThreadFactory("handle-connect-executor")) {
+
+ override def afterExecute(r: Runnable, t: Throwable): Unit = {
+ super.afterExecute(r, t)
+ if (t != null && NonFatal(t)) {
+ logError("Error in handleConnectExecutor is not handled properly", t)
+ }
+ }
+
+ }
private val serverChannel = ServerSocketChannel.open()
// used to track the SendingConnections waiting to do SASL negotiation
@@ -98,7 +141,10 @@ private[nio] class ConnectionManager(
new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection]
with SynchronizedMap[ConnectionManagerId, SendingConnection]
- private val messageStatuses = new HashMap[Int, MessageStatus]
+ // Tracks sent messages for which we are awaiting acknowledgements. Entries are added to this
+ // map when messages are sent and are removed when acknowledgement messages are received or when
+ // acknowledgement timeouts expire
+ private val messageStatuses = new HashMap[Int, MessageStatus] // [MessageId, MessageStatus]
private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
private val registerRequests = new SynchronizedQueue[SendingConnection]
@@ -153,17 +199,24 @@ private[nio] class ConnectionManager(
}
handleReadWriteExecutor.execute(new Runnable {
override def run() {
- var register: Boolean = false
try {
- register = conn.write()
- } finally {
- writeRunnableStarted.synchronized {
- writeRunnableStarted -= key
- val needReregister = register || conn.resetForceReregister()
- if (needReregister && conn.changeInterestForWrite()) {
- conn.registerInterest()
+ var register: Boolean = false
+ try {
+ register = conn.write()
+ } finally {
+ writeRunnableStarted.synchronized {
+ writeRunnableStarted -= key
+ val needReregister = register || conn.resetForceReregister()
+ if (needReregister && conn.changeInterestForWrite()) {
+ conn.registerInterest()
+ }
}
}
+ } catch {
+ case NonFatal(e) => {
+ logError("Error when writing to " + conn.getRemoteConnectionManagerId(), e)
+ conn.callOnExceptionCallbacks(e)
+ }
}
}
} )
@@ -187,16 +240,23 @@ private[nio] class ConnectionManager(
}
handleReadWriteExecutor.execute(new Runnable {
override def run() {
- var register: Boolean = false
try {
- register = conn.read()
- } finally {
- readRunnableStarted.synchronized {
- readRunnableStarted -= key
- if (register && conn.changeInterestForRead()) {
- conn.registerInterest()
+ var register: Boolean = false
+ try {
+ register = conn.read()
+ } finally {
+ readRunnableStarted.synchronized {
+ readRunnableStarted -= key
+ if (register && conn.changeInterestForRead()) {
+ conn.registerInterest()
+ }
}
}
+ } catch {
+ case NonFatal(e) => {
+ logError("Error when reading from " + conn.getRemoteConnectionManagerId(), e)
+ conn.callOnExceptionCallbacks(e)
+ }
}
}
} )
@@ -213,19 +273,25 @@ private[nio] class ConnectionManager(
handleConnectExecutor.execute(new Runnable {
override def run() {
+ try {
+ var tries: Int = 10
+ while (tries >= 0) {
+ if (conn.finishConnect(false)) return
+ // Sleep ?
+ Thread.sleep(1)
+ tries -= 1
+ }
- var tries: Int = 10
- while (tries >= 0) {
- if (conn.finishConnect(false)) return
- // Sleep ?
- Thread.sleep(1)
- tries -= 1
+ // fallback to previous behavior : we should not really come here since this method was
+ // triggered since channel became connectable : but at times, the first finishConnect need
+ // not succeed : hence the loop to retry a few 'times'.
+ conn.finishConnect(true)
+ } catch {
+ case NonFatal(e) => {
+ logError("Error when finishConnect for " + conn.getRemoteConnectionManagerId(), e)
+ conn.callOnExceptionCallbacks(e)
+ }
}
-
- // fallback to previous behavior : we should not really come here since this method was
- // triggered since channel became connectable : but at times, the first finishConnect need
- // not succeed : hence the loop to retry a few 'times'.
- conn.finishConnect(true)
}
} )
}
@@ -246,16 +312,16 @@ private[nio] class ConnectionManager(
handleConnectExecutor.execute(new Runnable {
override def run() {
try {
- conn.callOnExceptionCallback(e)
+ conn.callOnExceptionCallbacks(e)
} catch {
// ignore exceptions
- case e: Exception => logDebug("Ignoring exception", e)
+ case NonFatal(e) => logDebug("Ignoring exception", e)
}
try {
conn.close()
} catch {
// ignore exceptions
- case e: Exception => logDebug("Ignoring exception", e)
+ case NonFatal(e) => logDebug("Ignoring exception", e)
}
}
})
@@ -448,7 +514,7 @@ private[nio] class ConnectionManager(
messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId)
.foreach(status => {
logInfo("Notifying " + status)
- status.markDone(None)
+ status.failWithoutAck()
})
messageStatuses.retain((i, status) => {
@@ -477,7 +543,7 @@ private[nio] class ConnectionManager(
for (s <- messageStatuses.values
if s.connectionManagerId == sendingConnectionManagerId) {
logInfo("Notifying " + s)
- s.markDone(None)
+ s.failWithoutAck()
}
messageStatuses.retain((i, status) => {
@@ -492,7 +558,7 @@ private[nio] class ConnectionManager(
}
}
- def handleConnectionError(connection: Connection, e: Exception) {
+ def handleConnectionError(connection: Connection, e: Throwable) {
logInfo("Handling connection error on connection to " +
connection.getRemoteConnectionManagerId())
removeConnection(connection)
@@ -510,9 +576,17 @@ private[nio] class ConnectionManager(
val runnable = new Runnable() {
val creationTime = System.currentTimeMillis
def run() {
- logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms")
- handleMessage(connectionManagerId, message, connection)
- logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms")
+ try {
+ logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms")
+ handleMessage(connectionManagerId, message, connection)
+ logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms")
+ } catch {
+ case NonFatal(e) => {
+ logError("Error when handling messages from " +
+ connection.getRemoteConnectionManagerId(), e)
+ connection.callOnExceptionCallbacks(e)
+ }
+ }
}
}
handleMessageExecutor.execute(runnable)
@@ -532,7 +606,7 @@ private[nio] class ConnectionManager(
} else {
var replyToken : Array[Byte] = null
try {
- replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken)
+ replyToken = waitingConn.sparkSaslClient.response(securityMsg.getToken)
if (waitingConn.isSaslComplete()) {
logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId)
connectionsAwaitingSasl -= waitingConn.connectionId
@@ -566,7 +640,7 @@ private[nio] class ConnectionManager(
connection.synchronized {
if (connection.sparkSaslServer == null) {
logDebug("Creating sasl Server")
- connection.sparkSaslServer = new SparkSaslServer(securityManager)
+ connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager)
}
}
replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
@@ -651,7 +725,7 @@ private[nio] class ConnectionManager(
messageStatuses.get(bufferMessage.ackId) match {
case Some(status) => {
messageStatuses -= bufferMessage.ackId
- status.markDone(Some(message))
+ status.success(message)
}
case None => {
/**
@@ -691,9 +765,7 @@ private[nio] class ConnectionManager(
} catch {
case e: Exception => {
logError(s"Exception was thrown while processing message", e)
- val m = Message.createBufferMessage(bufferMessage.id)
- m.hasError = true
- ackMessage = Some(m)
+ ackMessage = Some(Message.createErrorMessage(e, bufferMessage.id))
}
} finally {
sendMessage(connectionManagerId, ackMessage.getOrElse {
@@ -712,7 +784,7 @@ private[nio] class ConnectionManager(
if (!conn.isSaslComplete()) {
conn.synchronized {
if (conn.sparkSaslClient == null) {
- conn.sparkSaslClient = new SparkSaslClient(securityManager)
+ conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager)
var firstResponse: Array[Byte] = null
try {
firstResponse = conn.sparkSaslClient.firstToken()
@@ -770,6 +842,12 @@ private[nio] class ConnectionManager(
val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId,
newConnectionId, securityManager)
+ newConnection.onException {
+ case (conn, e) => {
+ logError("Exception while sending message.", e)
+ reportSendingMessageFailure(message.id, e)
+ }
+ }
logTrace("creating new sending connection: " + newConnectionId)
registerRequests.enqueue(newConnection)
@@ -782,13 +860,36 @@ private[nio] class ConnectionManager(
"connectionid: " + connection.connectionId)
if (authEnabled) {
- checkSendAuthFirst(connectionManagerId, connection)
+ try {
+ checkSendAuthFirst(connectionManagerId, connection)
+ } catch {
+ case NonFatal(e) => {
+ reportSendingMessageFailure(message.id, e)
+ }
+ }
}
logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
connection.send(message)
wakeupSelector()
}
+ private def reportSendingMessageFailure(messageId: Int, e: Throwable): Unit = {
+ // need to tell sender it failed
+ messageStatuses.synchronized {
+ val s = messageStatuses.get(messageId)
+ s match {
+ case Some(msgStatus) => {
+ messageStatuses -= messageId
+ logInfo("Notifying " + msgStatus.connectionManagerId)
+ msgStatus.failure(e)
+ }
+ case None => {
+ logError("no messageStatus for failed message id: " + messageId)
+ }
+ }
+ }
+ }
+
private def wakeupSelector() {
selector.wakeup()
}
@@ -803,29 +904,62 @@ private[nio] class ConnectionManager(
: Future[Message] = {
val promise = Promise[Message]()
- val timeoutTask = new TimerTask {
- override def run(): Unit = {
+ // It's important that the TimerTask doesn't capture a reference to `message`, which can cause
+ // memory leaks since cancelled TimerTasks won't necessarily be garbage collected until the time
+ // at which they would originally be scheduled to run. Therefore, extract the message id
+ // from outside of the TimerTask closure (see SPARK-4393 for more context).
+ val messageId = message.id
+ // Keep a weak reference to the promise so that the completed promise may be garbage-collected
+ val promiseReference = new WeakReference(promise)
+ val timeoutTask: TimerTask = new TimerTask {
+ override def run(timeout: Timeout): Unit = {
messageStatuses.synchronized {
- messageStatuses.remove(message.id).foreach ( s => {
- promise.failure(
- new IOException("sendMessageReliably failed because ack " +
- s"was not received within $ackTimeout sec"))
- })
+ messageStatuses.remove(messageId).foreach { s =>
+ val e = new IOException("sendMessageReliably failed because ack " +
+ s"was not received within $ackTimeout sec")
+ val p = promiseReference.get
+ if (p != null) {
+ // Attempt to fail the promise with a Timeout exception
+ if (!p.tryFailure(e)) {
+ // If we reach here, then someone else has already signalled success or failure
+ // on this promise, so log a warning:
+ logError("Ignore error because promise is completed", e)
+ }
+ } else {
+ // The WeakReference was empty, which should never happen because
+ // sendMessageReliably's caller should have a strong reference to promise.future;
+ logError("Promise was garbage collected; this should never happen!", e)
+ }
+ }
}
}
}
+ val timeoutTaskHandle = ackTimeoutMonitor.newTimeout(timeoutTask, ackTimeout, TimeUnit.SECONDS)
+
val status = new MessageStatus(message, connectionManagerId, s => {
- timeoutTask.cancel()
- s.ackMessage match {
- case None => // Indicates a failure where we either never sent or never got ACK'd
- promise.failure(new IOException("sendMessageReliably failed without being ACK'd"))
- case Some(ackMessage) =>
+ timeoutTaskHandle.cancel()
+ s match {
+ case scala.util.Failure(e) =>
+ // Indicates a failure where we either never sent or never got ACK'd
+ if (!promise.tryFailure(e)) {
+ logWarning("Ignore error because promise is completed", e)
+ }
+ case scala.util.Success(ackMessage) =>
if (ackMessage.hasError) {
- promise.failure(
- new IOException("sendMessageReliably failed with ACK that signalled a remote error"))
+ val errorMsgByteBuf = ackMessage.asInstanceOf[BufferMessage].buffers.head
+ val errorMsgBytes = new Array[Byte](errorMsgByteBuf.limit())
+ errorMsgByteBuf.get(errorMsgBytes)
+ val errorMsg = new String(errorMsgBytes, UTF_8)
+ val e = new IOException(
+ s"sendMessageReliably failed with ACK that signalled a remote error: $errorMsg")
+ if (!promise.tryFailure(e)) {
+ logWarning("Ignore error because promise is completed", e)
+ }
} else {
- promise.success(ackMessage)
+ if (!promise.trySuccess(ackMessage)) {
+ logWarning("Drop ackMessage because promise is completed")
+ }
}
}
})
@@ -833,7 +967,6 @@ private[nio] class ConnectionManager(
messageStatuses += ((message.id, status))
}
- ackTimeoutMonitor.schedule(timeoutTask, ackTimeout * 1000)
sendMessage(connectionManagerId, message)
promise.future
}
@@ -843,7 +976,7 @@ private[nio] class ConnectionManager(
}
def stop() {
- ackTimeoutMonitor.cancel()
+ ackTimeoutMonitor.stop()
selectorThread.interrupt()
selectorThread.join()
selector.close()
diff --git a/core/src/main/scala/org/apache/spark/network/nio/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala
index 0b874c2891255..fb4a979b824c3 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/Message.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/Message.scala
@@ -22,6 +22,9 @@ import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
+import com.google.common.base.Charsets.UTF_8
+
+import org.apache.spark.util.Utils
private[nio] abstract class Message(val typ: Long, val id: Int) {
var senderAddress: InetSocketAddress = null
@@ -84,6 +87,19 @@ private[nio] object Message {
createBufferMessage(new Array[ByteBuffer](0), ackId)
}
+ /**
+ * Create a "negative acknowledgment" to notify a sender that an error occurred
+ * while processing its message. The exception's stacktrace will be formatted
+ * as a string, serialized into a byte array, and sent as the message payload.
+ */
+ def createErrorMessage(exception: Exception, ackId: Int): BufferMessage = {
+ val exceptionString = Utils.exceptionString(exception)
+ val serializedExceptionString = ByteBuffer.wrap(exceptionString.getBytes(UTF_8))
+ val errorMessage = createBufferMessage(serializedExceptionString, ackId)
+ errorMessage.hasError = true
+ errorMessage
+ }
+
def create(header: MessageChunkHeader): Message = {
val newMessage: Message = header.typ match {
case BUFFER_MESSAGE => new BufferMessage(header.id,
diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
index b389b9a2022c6..b2aec160635c7 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
@@ -19,12 +19,14 @@ package org.apache.spark.network.nio
import java.nio.ByteBuffer
-import scala.concurrent.Future
-
-import org.apache.spark.{SparkException, Logging, SecurityManager, SparkConf}
import org.apache.spark.network._
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.shuffle.BlockFetchingListener
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
+
+import scala.concurrent.Future
/**
@@ -71,20 +73,21 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
/**
* Tear down the transfer service.
*/
- override def stop(): Unit = {
+ override def close(): Unit = {
if (cm != null) {
cm.stop()
}
}
override def fetchBlocks(
- hostName: String,
+ host: String,
port: Int,
- blockIds: Seq[String],
+ execId: String,
+ blockIds: Array[String],
listener: BlockFetchingListener): Unit = {
checkInit()
- val cmId = new ConnectionManagerId(hostName, port)
+ val cmId = new ConnectionManagerId(host, port)
val blockMessageArray = new BlockMessageArray(blockIds.map { blockId =>
BlockMessage.fromGetBlock(GetBlock(BlockId(blockId)))
})
@@ -96,21 +99,33 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
val bufferMessage = message.asInstanceOf[BufferMessage]
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
- for (blockMessage <- blockMessageArray) {
- if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
- listener.onBlockFetchFailure(
- new SparkException(s"Unexpected message ${blockMessage.getType} received from $cmId"))
- } else {
- val blockId = blockMessage.getId
- val networkSize = blockMessage.getData.limit()
- listener.onBlockFetchSuccess(
- blockId.toString, new NioByteBufferManagedBuffer(blockMessage.getData))
+ // SPARK-4064: In some cases(eg. Remote block was removed) blockMessageArray may be empty.
+ if (blockMessageArray.isEmpty) {
+ blockIds.foreach { id =>
+ listener.onBlockFetchFailure(id, new SparkException(s"Received empty message from $cmId"))
+ }
+ } else {
+ for (blockMessage: BlockMessage <- blockMessageArray) {
+ val msgType = blockMessage.getType
+ if (msgType != BlockMessage.TYPE_GOT_BLOCK) {
+ if (blockMessage.getId != null) {
+ listener.onBlockFetchFailure(blockMessage.getId.toString,
+ new SparkException(s"Unexpected message $msgType received from $cmId"))
+ }
+ } else {
+ val blockId = blockMessage.getId
+ val networkSize = blockMessage.getData.limit()
+ listener.onBlockFetchSuccess(
+ blockId.toString, new NioManagedBuffer(blockMessage.getData))
+ }
}
}
}(cm.futureExecContext)
future.onFailure { case exception =>
- listener.onBlockFetchFailure(exception)
+ blockIds.foreach { blockId =>
+ listener.onBlockFetchFailure(blockId, exception)
+ }
}(cm.futureExecContext)
}
@@ -122,12 +137,13 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
override def uploadBlock(
hostname: String,
port: Int,
- blockId: String,
+ execId: String,
+ blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel)
: Future[Unit] = {
checkInit()
- val msg = PutBlock(BlockId(blockId), blockData.nioByteBuffer(), level)
+ val msg = PutBlock(blockId, blockData.nioByteBuffer(), level)
val blockMessageArray = new BlockMessageArray(BlockMessage.fromPutBlock(msg))
val remoteCmId = new ConnectionManagerId(hostName, port)
val reply = cm.sendMessageReliably(remoteCmId, blockMessageArray.toBufferMessage)
@@ -149,19 +165,15 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get)
Some(new BlockMessageArray(responseMessages).toBufferMessage)
} catch {
- case e: Exception => {
+ case e: Exception =>
logError("Exception handling buffer message", e)
- val errorMessage = Message.createBufferMessage(msg.id)
- errorMessage.hasError = true
- Some(errorMessage)
- }
+ Some(Message.createErrorMessage(e, msg.id))
}
case otherMessage: Any =>
- logError("Unknown type message received: " + otherMessage)
- val errorMessage = Message.createBufferMessage(msg.id)
- errorMessage.hasError = true
- Some(errorMessage)
+ val errorMsg = s"Received unknown message type: ${otherMessage.getClass.getName}"
+ logError(errorMsg)
+ Some(Message.createErrorMessage(new UnsupportedOperationException(errorMsg), msg.id))
}
}
@@ -170,13 +182,13 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
case BlockMessage.TYPE_PUT_BLOCK =>
val msg = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel)
logDebug("Received [" + msg + "]")
- putBlock(msg.id.toString, msg.data, msg.level)
+ putBlock(msg.id, msg.data, msg.level)
None
case BlockMessage.TYPE_GET_BLOCK =>
val msg = new GetBlock(blockMessage.getId)
logDebug("Received [" + msg + "]")
- val buffer = getBlock(msg.id.toString)
+ val buffer = getBlock(msg.id)
if (buffer == null) {
return None
}
@@ -186,20 +198,20 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
}
}
- private def putBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+ private def putBlock(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) {
val startTimeMs = System.currentTimeMillis()
logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes)
- blockDataManager.putBlockData(blockId, new NioByteBufferManagedBuffer(bytes), level)
+ blockDataManager.putBlockData(blockId, new NioManagedBuffer(bytes), level)
logDebug("PutBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs)
+ " with data size: " + bytes.limit)
}
- private def getBlock(blockId: String): ByteBuffer = {
+ private def getBlock(blockId: BlockId): ByteBuffer = {
val startTimeMs = System.currentTimeMillis()
logDebug("GetBlock " + blockId + " started from " + startTimeMs)
- val buffer = blockDataManager.getBlockData(blockId).orNull
+ val buffer = blockDataManager.getBlockData(blockId)
logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs)
+ " and got buffer " + buffer)
- if (buffer == null) null else buffer.nioByteBuffer()
+ buffer.nioByteBuffer()
}
}
diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala
index e2fc9c649925e..436dbed1730bc 100644
--- a/core/src/main/scala/org/apache/spark/package.scala
+++ b/core/src/main/scala/org/apache/spark/package.scala
@@ -44,5 +44,5 @@ package org.apache
package object spark {
// For package docs only
- val SPARK_VERSION = "1.2.0-SNAPSHOT"
+ val SPARK_VERSION = "1.3.0-SNAPSHOT"
}
diff --git a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala
index 3155dfe165664..637492a97551b 100644
--- a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala
+++ b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala
@@ -17,7 +17,7 @@
package org.apache.spark.partial
-import cern.jet.stat.Probability
+import org.apache.commons.math3.distribution.NormalDistribution
/**
* An ApproximateEvaluator for counts.
@@ -46,7 +46,8 @@ private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double)
val mean = (sum + 1 - p) / p
val variance = (sum + 1) * (1 - p) / (p * p)
val stdev = math.sqrt(variance)
- val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2)
+ val confFactor = new NormalDistribution().
+ inverseCumulativeProbability(1 - (1 - confidence) / 2)
val low = mean - confFactor * stdev
val high = mean + confFactor * stdev
new BoundedDouble(mean, confidence, low, high)
diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala
index 8bb78123e3c9c..3ef3cc219dec6 100644
--- a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala
+++ b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala
@@ -24,7 +24,7 @@ import scala.collection.Map
import scala.collection.mutable.HashMap
import scala.reflect.ClassTag
-import cern.jet.stat.Probability
+import org.apache.commons.math3.distribution.NormalDistribution
import org.apache.spark.util.collection.OpenHashMap
@@ -55,7 +55,8 @@ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, conf
new HashMap[T, BoundedDouble]
} else {
val p = outputsMerged.toDouble / totalOutputs
- val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2)
+ val confFactor = new NormalDistribution().
+ inverseCumulativeProbability(1 - (1 - confidence) / 2)
val result = new JHashMap[T, BoundedDouble](sums.size)
sums.foreach { case (key, sum) =>
val mean = (sum + 1 - p) / p
diff --git a/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala
index d24959cba8727..787a21a61fdcf 100644
--- a/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala
+++ b/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala
@@ -17,7 +17,7 @@
package org.apache.spark.partial
-import cern.jet.stat.Probability
+import org.apache.commons.math3.distribution.{NormalDistribution, TDistribution}
import org.apache.spark.util.StatCounter
@@ -45,9 +45,10 @@ private[spark] class MeanEvaluator(totalOutputs: Int, confidence: Double)
val stdev = math.sqrt(counter.sampleVariance / counter.count)
val confFactor = {
if (counter.count > 100) {
- Probability.normalInverse(1 - (1 - confidence) / 2)
+ new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2)
} else {
- Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt)
+ val degreesOfFreedom = (counter.count - 1).toInt
+ new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2)
}
}
val low = mean - confFactor * stdev
diff --git a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala b/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala
index 92915ee66d29f..828bf96c2c0bd 100644
--- a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala
+++ b/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala
@@ -17,7 +17,7 @@
package org.apache.spark.partial
-import cern.jet.stat.Probability
+import org.apache.commons.math3.distribution.{TDistribution, NormalDistribution}
/**
* A utility class for caching Student's T distribution values for a given confidence level
@@ -25,8 +25,10 @@ import cern.jet.stat.Probability
* confidence intervals for many keys.
*/
private[spark] class StudentTCacher(confidence: Double) {
+
val NORMAL_APPROX_SAMPLE_SIZE = 100 // For samples bigger than this, use Gaussian approximation
- val normalApprox = Probability.normalInverse(1 - (1 - confidence) / 2)
+
+ val normalApprox = new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2)
val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0)
def get(sampleSize: Long): Double = {
@@ -35,7 +37,8 @@ private[spark] class StudentTCacher(confidence: Double) {
} else {
val size = sampleSize.toInt
if (cache(size) < 0) {
- cache(size) = Probability.studentTInverse(1 - confidence, size - 1)
+ val tDist = new TDistribution(size - 1)
+ cache(size) = tDist.inverseCumulativeProbability(1 - (1 - confidence) / 2)
}
cache(size)
}
diff --git a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
index d5336284571d2..1753c2561b678 100644
--- a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
+++ b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
@@ -17,7 +17,7 @@
package org.apache.spark.partial
-import cern.jet.stat.Probability
+import org.apache.commons.math3.distribution.{TDistribution, NormalDistribution}
import org.apache.spark.util.StatCounter
@@ -55,9 +55,10 @@ private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double)
val sumStdev = math.sqrt(sumVar)
val confFactor = {
if (counter.count > 100) {
- Probability.normalInverse(1 - (1 - confidence) / 2)
+ new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2)
} else {
- Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt)
+ val degreesOfFreedom = (counter.count - 1).toInt
+ new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2)
}
}
val low = sumEstimate - confFactor * sumStdev
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index b62f3fbdc4a15..9f9f10b7ebc3a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -24,14 +24,11 @@ import scala.concurrent.ExecutionContext.Implicits.global
import scala.reflect.ClassTag
import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
-import org.apache.spark.annotation.Experimental
/**
- * :: Experimental ::
* A set of asynchronous RDD actions available through an implicit conversion.
* Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
*/
-@Experimental
class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging {
/**
@@ -78,16 +75,18 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1
if (partsScanned > 0) {
- // If we didn't find any rows after the first iteration, just try all partitions next.
+ // If we didn't find any rows after the previous iteration, quadruple and retry.
// Otherwise, interpolate the number of partitions we need to try, but overestimate it
- // by 50%.
+ // by 50%. We also cap the estimation in the end.
if (results.size == 0) {
- numPartsToTry = totalParts - 1
+ numPartsToTry = partsScanned * 4
} else {
- numPartsToTry = (1.5 * num * partsScanned / results.size).toInt
+ // the left side of max is >=1 whenever partsScanned >= 2
+ numPartsToTry = Math.max(1,
+ (1.5 * num * partsScanned / results.size).toInt - partsScanned)
+ numPartsToTry = Math.min(numPartsToTry, partsScanned * 4)
}
}
- numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
val left = num - results.size
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala
new file mode 100644
index 0000000000000..6e66ddbdef788
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.hadoop.conf.{ Configurable, Configuration }
+import org.apache.hadoop.io.Writable
+import org.apache.hadoop.mapreduce._
+import org.apache.spark.input.StreamFileInputFormat
+import org.apache.spark.{ Partition, SparkContext }
+
+private[spark] class BinaryFileRDD[T](
+ sc: SparkContext,
+ inputFormatClass: Class[_ <: StreamFileInputFormat[T]],
+ keyClass: Class[String],
+ valueClass: Class[T],
+ @transient conf: Configuration,
+ minPartitions: Int)
+ extends NewHadoopRDD[String, T](sc, inputFormatClass, keyClass, valueClass, conf) {
+
+ override def getPartitions: Array[Partition] = {
+ val inputFormat = inputFormatClass.newInstance
+ inputFormat match {
+ case configurable: Configurable =>
+ configurable.setConf(conf)
+ case _ =>
+ }
+ val jobContext = newJobContext(conf, jobId)
+ inputFormat.setMinPartitions(jobContext, minPartitions)
+ val rawSplits = inputFormat.getSplits(jobContext).toArray
+ val result = new Array[Partition](rawSplits.size)
+ for (i <- 0 until rawSplits.size) {
+ result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
+ }
+ result
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
index 2673ec22509e9..fffa1911f5bc2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -84,5 +84,9 @@ class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds
"Attempted to use %s after its blocks have been removed!".format(toString))
}
}
+
+ protected def getBlockIdLocations(): Map[BlockId, Seq[String]] = {
+ locations_
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
index 4908711d17db7..1cbd684224b7c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
@@ -22,6 +22,7 @@ import java.io.{IOException, ObjectOutputStream}
import scala.reflect.ClassTag
import org.apache.spark._
+import org.apache.spark.util.Utils
private[spark]
class CartesianPartition(
@@ -36,7 +37,7 @@ class CartesianPartition(
override val index: Int = idx
@throws(classOf[IOException])
- private def writeObject(oos: ObjectOutputStream) {
+ private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
// Update the reference to parent split at the time of task serialization
s1 = rdd1.partitions(s1Index)
s2 = rdd2.partitions(s2Index)
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index fabb882cdd4b3..ffc0a8a6d67eb 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -27,6 +27,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer}
+import org.apache.spark.util.Utils
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleHandle
@@ -39,7 +40,7 @@ private[spark] case class NarrowCoGroupSplitDep(
) extends CoGroupSplitDep {
@throws(classOf[IOException])
- private def writeObject(oos: ObjectOutputStream) {
+ private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
// Update the reference to parent split at the time of task serialization
split = rdd.partitions(splitIndex)
oos.defaultWriteObject()
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
index 11ebafbf6d457..9fab1d78abb04 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
@@ -25,6 +25,7 @@ import scala.language.existentials
import scala.reflect.ClassTag
import org.apache.spark._
+import org.apache.spark.util.Utils
/**
* Class that captures a coalesced RDD by essentially keeping track of parent partitions
@@ -42,7 +43,7 @@ private[spark] case class CoalescedRDDPartition(
var parents: Seq[Partition] = parentsIndices.map(rdd.partitions(_))
@throws(classOf[IOException])
- private def writeObject(oos: ObjectOutputStream) {
+ private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
// Update the reference to parent partition at the time of task serialization
parents = parentsIndices.map(rdd.partitions(_))
oos.defaultWriteObject()
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 6b63eb23e9ee1..a157e36e2286e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -46,7 +46,6 @@ import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD
import org.apache.spark.util.{NextIterator, Utils}
import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation}
-
/**
* A Spark split class that wraps around a Hadoop InputSplit.
*/
@@ -132,27 +131,47 @@ class HadoopRDD[K, V](
// used to build JobTracker ID
private val createTime = new Date()
+ private val shouldCloneJobConf = sc.conf.get("spark.hadoop.cloneConf", "false").toBoolean
+
// Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads.
protected def getJobConf(): JobConf = {
val conf: Configuration = broadcastedConf.value.value
- if (conf.isInstanceOf[JobConf]) {
- // A user-broadcasted JobConf was provided to the HadoopRDD, so always use it.
- conf.asInstanceOf[JobConf]
- } else if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) {
- // getJobConf() has been called previously, so there is already a local cache of the JobConf
- // needed by this RDD.
- HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf]
- } else {
- // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the
- // local process. The local cache is accessed through HadoopRDD.putCachedMetadata().
- // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects.
- // Synchronize to prevent ConcurrentModificationException (Spark-1097, Hadoop-10456).
+ if (shouldCloneJobConf) {
+ // Hadoop Configuration objects are not thread-safe, which may lead to various problems if
+ // one job modifies a configuration while another reads it (SPARK-2546). This problem occurs
+ // somewhat rarely because most jobs treat the configuration as though it's immutable. One
+ // solution, implemented here, is to clone the Configuration object. Unfortunately, this
+ // clone can be very expensive. To avoid unexpected performance regressions for workloads and
+ // Hadoop versions that do not suffer from these thread-safety issues, this cloning is
+ // disabled by default.
HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized {
+ logDebug("Cloning Hadoop Configuration")
val newJobConf = new JobConf(conf)
- initLocalJobConfFuncOpt.map(f => f(newJobConf))
- HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
+ if (!conf.isInstanceOf[JobConf]) {
+ initLocalJobConfFuncOpt.map(f => f(newJobConf))
+ }
newJobConf
}
+ } else {
+ if (conf.isInstanceOf[JobConf]) {
+ logDebug("Re-using user-broadcasted JobConf")
+ conf.asInstanceOf[JobConf]
+ } else if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) {
+ logDebug("Re-using cached JobConf")
+ HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf]
+ } else {
+ // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the
+ // local process. The local cache is accessed through HadoopRDD.putCachedMetadata().
+ // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects.
+ // Synchronize to prevent ConcurrentModificationException (SPARK-1097, HADOOP-10456).
+ HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized {
+ logDebug("Creating new JobConf and caching it for later re-use")
+ val newJobConf = new JobConf(conf)
+ initLocalJobConfFuncOpt.map(f => f(newJobConf))
+ HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
+ newJobConf
+ }
+ }
}
}
@@ -192,11 +211,25 @@ class HadoopRDD[K, V](
val split = theSplit.asInstanceOf[HadoopPartition]
logInfo("Input split: " + split.inputSplit)
- var reader: RecordReader[K, V] = null
val jobConf = getJobConf()
+
+ val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
+ // Find a function that will return the FileSystem bytes read by this thread. Do this before
+ // creating RecordReader, because RecordReader's constructor might read some bytes
+ val bytesReadCallback = if (split.inputSplit.value.isInstanceOf[FileSplit]) {
+ SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(
+ split.inputSplit.value.asInstanceOf[FileSplit].getPath, jobConf)
+ } else {
+ None
+ }
+ if (bytesReadCallback.isDefined) {
+ context.taskMetrics.inputMetrics = Some(inputMetrics)
+ }
+
+ var reader: RecordReader[K, V] = null
val inputFormat = getInputFormat(jobConf)
HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
- context.getStageId, theSplit.index, context.getAttemptId.toInt, jobConf)
+ context.stageId, theSplit.index, context.attemptId.toInt, jobConf)
reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
// Register an on-task-completion callback to close the input stream.
@@ -204,18 +237,7 @@ class HadoopRDD[K, V](
val key: K = reader.createKey()
val value: V = reader.createValue()
- // Set the task input metrics.
- val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
- try {
- /* bytesRead may not exactly equal the bytes read by a task: split boundaries aren't
- * always at record boundaries, so tasks may need to read into other splits to complete
- * a record. */
- inputMetrics.bytesRead = split.inputSplit.value.getLength()
- } catch {
- case e: java.io.IOException =>
- logWarning("Unable to get input size to set InputMetrics for task", e)
- }
- context.taskMetrics.inputMetrics = Some(inputMetrics)
+ var recordsSinceMetricsUpdate = 0
override def getNext() = {
try {
@@ -224,12 +246,36 @@ class HadoopRDD[K, V](
case eof: EOFException =>
finished = true
}
+
+ // Update bytes read metric every few records
+ if (recordsSinceMetricsUpdate == HadoopRDD.RECORDS_BETWEEN_BYTES_READ_METRIC_UPDATES
+ && bytesReadCallback.isDefined) {
+ recordsSinceMetricsUpdate = 0
+ val bytesReadFn = bytesReadCallback.get
+ inputMetrics.bytesRead = bytesReadFn()
+ } else {
+ recordsSinceMetricsUpdate += 1
+ }
(key, value)
}
override def close() {
try {
reader.close()
+ if (bytesReadCallback.isDefined) {
+ val bytesReadFn = bytesReadCallback.get
+ inputMetrics.bytesRead = bytesReadFn()
+ } else if (split.inputSplit.value.isInstanceOf[FileSplit]) {
+ // If we can't get the bytes read from the FS stats, fall back to the split size,
+ // which may be inaccurate.
+ try {
+ inputMetrics.bytesRead = split.inputSplit.value.getLength
+ context.taskMetrics.inputMetrics = Some(inputMetrics)
+ } catch {
+ case e: java.io.IOException =>
+ logWarning("Unable to get input size to set InputMetrics for task", e)
+ }
+ }
} catch {
case e: Exception => {
if (!Utils.inShutdown()) {
@@ -276,9 +322,15 @@ class HadoopRDD[K, V](
}
private[spark] object HadoopRDD extends Logging {
- /** Constructing Configuration objects is not threadsafe, use this lock to serialize. */
+ /**
+ * Configuration's constructor is not threadsafe (see SPARK-1097 and HADOOP-10456).
+ * Therefore, we synchronize on this lock before calling new JobConf() or new Configuration().
+ */
val CONFIGURATION_INSTANTIATION_LOCK = new Object()
+ /** Update the input bytes read metric each time this number of records has been read */
+ val RECORDS_BETWEEN_BYTES_READ_METRIC_UPDATES = 256
+
/**
* The three methods below are helpers for accessing the local map, a property of the SparkEnv of
* the local process.
diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
index 0e38f224ac81d..642a12c1edf6c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -21,8 +21,11 @@ import java.sql.{Connection, ResultSet}
import scala.reflect.ClassTag
-import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.api.java.function.{Function => JFunction}
+import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.util.NextIterator
+import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition {
override def index = idx
@@ -125,5 +128,82 @@ object JdbcRDD {
def resultSetToObjectArray(rs: ResultSet): Array[Object] = {
Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
}
-}
+ trait ConnectionFactory extends Serializable {
+ @throws[Exception]
+ def getConnection: Connection
+ }
+
+ /**
+ * Create an RDD that executes an SQL query on a JDBC connection and reads results.
+ * For usage example, see test case JavaAPISuite.testJavaJdbcRDD.
+ *
+ * @param connectionFactory a factory that returns an open Connection.
+ * The RDD takes care of closing the connection.
+ * @param sql the text of the query.
+ * The query must contain two ? placeholders for parameters used to partition the results.
+ * E.g. "select title, author from books where ? <= id and id <= ?"
+ * @param lowerBound the minimum value of the first placeholder
+ * @param upperBound the maximum value of the second placeholder
+ * The lower and upper bounds are inclusive.
+ * @param numPartitions the number of partitions.
+ * Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
+ * the query would be executed twice, once with (1, 10) and once with (11, 20)
+ * @param mapRow a function from a ResultSet to a single row of the desired result type(s).
+ * This should only call getInt, getString, etc; the RDD takes care of calling next.
+ * The default maps a ResultSet to an array of Object.
+ */
+ def create[T](
+ sc: JavaSparkContext,
+ connectionFactory: ConnectionFactory,
+ sql: String,
+ lowerBound: Long,
+ upperBound: Long,
+ numPartitions: Int,
+ mapRow: JFunction[ResultSet, T]): JavaRDD[T] = {
+
+ val jdbcRDD = new JdbcRDD[T](
+ sc.sc,
+ () => connectionFactory.getConnection,
+ sql,
+ lowerBound,
+ upperBound,
+ numPartitions,
+ (resultSet: ResultSet) => mapRow.call(resultSet))(fakeClassTag)
+
+ new JavaRDD[T](jdbcRDD)(fakeClassTag)
+ }
+
+ /**
+ * Create an RDD that executes an SQL query on a JDBC connection and reads results. Each row is
+ * converted into a `Object` array. For usage example, see test case JavaAPISuite.testJavaJdbcRDD.
+ *
+ * @param connectionFactory a factory that returns an open Connection.
+ * The RDD takes care of closing the connection.
+ * @param sql the text of the query.
+ * The query must contain two ? placeholders for parameters used to partition the results.
+ * E.g. "select title, author from books where ? <= id and id <= ?"
+ * @param lowerBound the minimum value of the first placeholder
+ * @param upperBound the maximum value of the second placeholder
+ * The lower and upper bounds are inclusive.
+ * @param numPartitions the number of partitions.
+ * Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
+ * the query would be executed twice, once with (1, 10) and once with (11, 20)
+ */
+ def create(
+ sc: JavaSparkContext,
+ connectionFactory: ConnectionFactory,
+ sql: String,
+ lowerBound: Long,
+ upperBound: Long,
+ numPartitions: Int): JavaRDD[Array[Object]] = {
+
+ val mapRow = new JFunction[ResultSet, Array[Object]] {
+ override def call(resultSet: ResultSet): Array[Object] = {
+ resultSetToObjectArray(resultSet)
+ }
+ }
+
+ create(sc, connectionFactory, sql, lowerBound, upperBound, numPartitions, mapRow)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 0cccdefc5ee09..e55d03d391e03 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -25,6 +25,7 @@ import scala.reflect.ClassTag
import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.mapreduce.lib.input.FileSplit
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.input.WholeTextFileInputFormat
@@ -34,8 +35,10 @@ import org.apache.spark.Partition
import org.apache.spark.SerializableWritable
import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.executor.{DataReadMethod, InputMetrics}
+import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD
import org.apache.spark.util.Utils
+import org.apache.spark.deploy.SparkHadoopUtil
private[spark] class NewHadoopPartition(
rddId: Int,
@@ -105,6 +108,20 @@ class NewHadoopRDD[K, V](
val split = theSplit.asInstanceOf[NewHadoopPartition]
logInfo("Input split: " + split.serializableHadoopSplit)
val conf = confBroadcast.value.value
+
+ val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
+ // Find a function that will return the FileSystem bytes read by this thread. Do this before
+ // creating RecordReader, because RecordReader's constructor might read some bytes
+ val bytesReadCallback = if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) {
+ SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(
+ split.serializableHadoopSplit.value.asInstanceOf[FileSplit].getPath, conf)
+ } else {
+ None
+ }
+ if (bytesReadCallback.isDefined) {
+ context.taskMetrics.inputMetrics = Some(inputMetrics)
+ }
+
val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0)
val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
val format = inputFormatClass.newInstance
@@ -117,22 +134,11 @@ class NewHadoopRDD[K, V](
split.serializableHadoopSplit.value, hadoopAttemptContext)
reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
- val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
- try {
- /* bytesRead may not exactly equal the bytes read by a task: split boundaries aren't
- * always at record boundaries, so tasks may need to read into other splits to complete
- * a record. */
- inputMetrics.bytesRead = split.serializableHadoopSplit.value.getLength()
- } catch {
- case e: Exception =>
- logWarning("Unable to get input split size in order to set task input bytes", e)
- }
- context.taskMetrics.inputMetrics = Some(inputMetrics)
-
// Register an on-task-completion callback to close the input stream.
context.addTaskCompletionListener(context => close())
var havePair = false
var finished = false
+ var recordsSinceMetricsUpdate = 0
override def hasNext: Boolean = {
if (!finished && !havePair) {
@@ -147,12 +153,39 @@ class NewHadoopRDD[K, V](
throw new java.util.NoSuchElementException("End of stream")
}
havePair = false
+
+ // Update bytes read metric every few records
+ if (recordsSinceMetricsUpdate == HadoopRDD.RECORDS_BETWEEN_BYTES_READ_METRIC_UPDATES
+ && bytesReadCallback.isDefined) {
+ recordsSinceMetricsUpdate = 0
+ val bytesReadFn = bytesReadCallback.get
+ inputMetrics.bytesRead = bytesReadFn()
+ } else {
+ recordsSinceMetricsUpdate += 1
+ }
+
(reader.getCurrentKey, reader.getCurrentValue)
}
private def close() {
try {
reader.close()
+
+ // Update metrics with final amount
+ if (bytesReadCallback.isDefined) {
+ val bytesReadFn = bytesReadCallback.get
+ inputMetrics.bytesRead = bytesReadFn()
+ } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) {
+ // If we can't get the bytes read from the FS stats, fall back to the split size,
+ // which may be inaccurate.
+ try {
+ inputMetrics.bytesRead = split.serializableHadoopSplit.value.getLength
+ context.taskMetrics.inputMetrics = Some(inputMetrics)
+ } catch {
+ case e: java.io.IOException =>
+ logWarning("Unable to get input size to set InputMetrics for task", e)
+ }
+ }
} catch {
case e: Exception => {
if (!Utils.inShutdown()) {
@@ -233,7 +266,7 @@ private[spark] class WholeTextFileRDD(
case _ =>
}
val jobContext = newJobContext(conf, jobId)
- inputFormat.setMaxSplitSize(jobContext, minPartitions)
+ inputFormat.setMinPartitions(jobContext, minPartitions)
val rawSplits = inputFormat.getSplits(jobContext).toArray
val result = new Array[Partition](rawSplits.size)
for (i <- 0 until rawSplits.size) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 0d97506450a7f..8c2c959e73bb6 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -28,18 +28,20 @@ import scala.reflect.ClassTag
import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
import org.apache.hadoop.conf.{Configurable, Configuration}
-import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat,
-RecordWriter => NewRecordWriter, SparkHadoopMapReduceUtil}
+RecordWriter => NewRecordWriter}
import org.apache.spark._
import org.apache.spark.Partitioner.defaultPartitioner
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.executor.{DataWriteMethod, OutputMetrics}
+import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.serializer.Serializer
import org.apache.spark.util.Utils
@@ -315,8 +317,15 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
@deprecated("Use reduceByKeyLocally", "1.0.0")
def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func)
- /** Count the number of elements for each key, and return the result to the master as a Map. */
- def countByKey(): Map[K, Long] = self.map(_._1).countByValue()
+ /**
+ * Count the number of elements for each key, collecting the results to a local Map.
+ *
+ * Note that this method should only be used if the resulting map is expected to be small, as
+ * the whole thing is loaded into the driver's memory.
+ * To handle very large results, consider using rdd.mapValues(_ => 1L).reduceByKey(_ + _), which
+ * returns an RDD[T, Long] instead of a map.
+ */
+ def countByKey(): Map[K, Long] = self.mapValues(_ => 1L).reduceByKey(_ + _).collect().toMap
/**
* :: Experimental ::
@@ -954,30 +963,40 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
}
val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => {
+ val config = wrappedConf.value
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
- val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt
+ val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" */
- val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.getPartitionId,
+ val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
attemptNumber)
- val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
+ val hadoopContext = newTaskAttemptContext(config, attemptId)
val format = outfmt.newInstance
format match {
- case c: Configurable => c.setConf(wrappedConf.value)
+ case c: Configurable => c.setConf(config)
case _ => ()
}
val committer = format.getOutputCommitter(hadoopContext)
committer.setupTask(hadoopContext)
+
+ val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config)
+
val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]]
try {
+ var recordsWritten = 0L
while (iter.hasNext) {
val pair = iter.next()
writer.write(pair._1, pair._2)
+
+ // Update bytes written metric every few records
+ maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten)
+ recordsWritten += 1
}
} finally {
writer.close(hadoopContext)
}
committer.commitTask(hadoopContext)
+ bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() }
1
} : Int
@@ -998,6 +1017,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
def saveAsHadoopDataset(conf: JobConf) {
// Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038).
val hadoopConf = conf
+ val wrappedConf = new SerializableWritable(hadoopConf)
val outputFormatInstance = hadoopConf.getOutputFormat
val keyClass = hadoopConf.getOutputKeyClass
val valueClass = hadoopConf.getOutputValueClass
@@ -1025,29 +1045,56 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
writer.preSetup()
val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => {
+ val config = wrappedConf.value
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
- val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt
+ val attemptNumber = (context.attemptId % Int.MaxValue).toInt
+
+ val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config)
- writer.setup(context.getStageId, context.getPartitionId, attemptNumber)
+ writer.setup(context.stageId, context.partitionId, attemptNumber)
writer.open()
try {
- var count = 0
+ var recordsWritten = 0L
while (iter.hasNext) {
val record = iter.next()
- count += 1
writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
+
+ // Update bytes written metric every few records
+ maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten)
+ recordsWritten += 1
}
} finally {
writer.close()
}
writer.commit()
+ bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() }
}
self.context.runJob(self, writeToFile)
writer.commitJob()
}
+ private def initHadoopOutputMetrics(context: TaskContext, config: Configuration)
+ : (OutputMetrics, Option[() => Long]) = {
+ val bytesWrittenCallback = Option(config.get("mapreduce.output.fileoutputformat.outputdir"))
+ .map(new Path(_))
+ .flatMap(SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(_, config))
+ val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop)
+ if (bytesWrittenCallback.isDefined) {
+ context.taskMetrics.outputMetrics = Some(outputMetrics)
+ }
+ (outputMetrics, bytesWrittenCallback)
+ }
+
+ private def maybeUpdateOutputMetrics(bytesWrittenCallback: Option[() => Long],
+ outputMetrics: OutputMetrics, recordsWritten: Long): Unit = {
+ if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0
+ && bytesWrittenCallback.isDefined) {
+ bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() }
+ }
+ }
+
/**
* Return an RDD with the keys of each tuple.
*/
@@ -1064,3 +1111,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
private[spark] def keyOrdering: Option[Ordering[K]] = Option(ord)
}
+
+private[spark] object PairRDDFunctions {
+ val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
index 66c71bf7e8bb5..87b22de6ae697 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
@@ -48,7 +48,7 @@ private[spark] class ParallelCollectionPartition[T: ClassTag](
override def index: Int = slice
@throws(classOf[IOException])
- private def writeObject(out: ObjectOutputStream): Unit = {
+ private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
val sfactory = SparkEnv.get.serializer
@@ -67,7 +67,7 @@ private[spark] class ParallelCollectionPartition[T: ClassTag](
}
@throws(classOf[IOException])
- private def readObject(in: ObjectInputStream): Unit = {
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
val sfactory = SparkEnv.get.serializer
sfactory match {
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
index 0c2cd7a24783b..92b0641d0fb6e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
@@ -22,6 +22,7 @@ import java.io.{IOException, ObjectOutputStream}
import scala.reflect.ClassTag
import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}
+import org.apache.spark.util.Utils
/**
* Class representing partitions of PartitionerAwareUnionRDD, which maintains the list of
@@ -38,7 +39,7 @@ class PartitionerAwareUnionRDDPartition(
override def hashCode(): Int = idx
@throws(classOf[IOException])
- private def writeObject(oos: ObjectOutputStream) {
+ private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
// Update the reference to parent partition at the time of task serialization
parents = rdds.map(_.partitions(index)).toArray
oos.defaultWriteObject()
diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index 5d77d37378458..56ac7a69be0d3 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -131,7 +131,6 @@ private[spark] class PipedRDD[T: ClassTag](
// Start a thread to feed the process input from our parent's iterator
new Thread("stdin writer for " + command) {
override def run() {
- SparkEnv.set(env)
val out = new PrintWriter(proc.getOutputStream)
// input the pipe context firstly
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 2aba40d152e3e..3add4a76192ca 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -21,6 +21,7 @@ import java.util.{Properties, Random}
import scala.collection.{mutable, Map}
import scala.collection.mutable.ArrayBuffer
+import scala.language.implicitConversions
import scala.reflect.{classTag, ClassTag}
import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
@@ -28,6 +29,7 @@ import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
+import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.TextOutputFormat
import org.apache.spark._
@@ -43,7 +45,8 @@ import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.{BoundedPriorityQueue, Utils, CallSite}
import org.apache.spark.util.collection.OpenHashMap
-import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}
+import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, BernoulliCellSampler,
+ SamplingUtils}
/**
* A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -375,7 +378,8 @@ abstract class RDD[T: ClassTag](
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
- new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), true, seed)
+ new PartitionwiseSampledRDD[T, T](
+ this, new BernoulliCellSampler[T](x(0), x(1)), true, seed)
}.toArray
}
@@ -927,32 +931,15 @@ abstract class RDD[T: ClassTag](
}
/**
- * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final
- * combine step happens locally on the master, equivalent to running a single reduce task.
+ * Return the count of each unique value in this RDD as a local map of (value, count) pairs.
+ *
+ * Note that this method should only be used if the resulting map is expected to be small, as
+ * the whole thing is loaded into the driver's memory.
+ * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which
+ * returns an RDD[T, Long] instead of a map.
*/
def countByValue()(implicit ord: Ordering[T] = null): Map[T, Long] = {
- if (elementClassTag.runtimeClass.isArray) {
- throw new SparkException("countByValue() does not support arrays")
- }
- // TODO: This should perhaps be distributed by default.
- val countPartition = (iter: Iterator[T]) => {
- val map = new OpenHashMap[T,Long]
- iter.foreach {
- t => map.changeValue(t, 1L, _ + 1L)
- }
- Iterator(map)
- }: Iterator[OpenHashMap[T,Long]]
- val mergeMaps = (m1: OpenHashMap[T,Long], m2: OpenHashMap[T,Long]) => {
- m2.foreach { case (key, value) =>
- m1.changeValue(key, value, _ + value)
- }
- m1
- }: OpenHashMap[T,Long]
- val myResult = mapPartitions(countPartition).reduce(mergeMaps)
- // Convert to a Scala mutable map
- val mutableResult = scala.collection.mutable.Map[T,Long]()
- myResult.foreach { case (k, v) => mutableResult.put(k, v) }
- mutableResult
+ map(value => (value, null)).countByKey()
}
/**
@@ -1079,15 +1066,17 @@ abstract class RDD[T: ClassTag](
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1
if (partsScanned > 0) {
- // If we didn't find any rows after the previous iteration, quadruple and retry. Otherwise,
+ // If we didn't find any rows after the previous iteration, quadruple and retry. Otherwise,
// interpolate the number of partitions we need to try, but overestimate it by 50%.
+ // We also cap the estimation in the end.
if (buf.size == 0) {
numPartsToTry = partsScanned * 4
} else {
- numPartsToTry = (1.5 * num * partsScanned / buf.size).toInt
+ // the left side of max is >=1 whenever partsScanned >= 2
+ numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1)
+ numPartsToTry = Math.min(numPartsToTry, partsScanned * 4)
}
}
- numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
val left = num - buf.size
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
@@ -1109,7 +1098,7 @@ abstract class RDD[T: ClassTag](
}
/**
- * Returns the top K (largest) elements from this RDD as defined by the specified
+ * Returns the top k (largest) elements from this RDD as defined by the specified
* implicit Ordering[T]. This does the opposite of [[takeOrdered]]. For example:
* {{{
* sc.parallelize(Seq(10, 4, 2, 12, 3)).top(1)
@@ -1119,14 +1108,14 @@ abstract class RDD[T: ClassTag](
* // returns Array(6, 5)
* }}}
*
- * @param num the number of top elements to return
+ * @param num k, the number of top elements to return
* @param ord the implicit ordering for T
* @return an array of top elements
*/
def top(num: Int)(implicit ord: Ordering[T]): Array[T] = takeOrdered(num)(ord.reverse)
/**
- * Returns the first K (smallest) elements from this RDD as defined by the specified
+ * Returns the first k (smallest) elements from this RDD as defined by the specified
* implicit Ordering[T] and maintains the ordering. This does the opposite of [[top]].
* For example:
* {{{
@@ -1137,7 +1126,7 @@ abstract class RDD[T: ClassTag](
* // returns Array(2, 3)
* }}}
*
- * @param num the number of top elements to return
+ * @param num k, the number of elements to return
* @param ord the implicit ordering for T
* @return an array of top elements
*/
@@ -1215,7 +1204,7 @@ abstract class RDD[T: ClassTag](
*/
def checkpoint() {
if (context.checkpointDir.isEmpty) {
- throw new Exception("Checkpoint directory has not been set in the SparkContext")
+ throw new SparkException("Checkpoint directory has not been set in the SparkContext")
} else if (checkpointData.isEmpty) {
checkpointData = Some(new RDDCheckpointData(this))
checkpointData.get.markForCheckpoint()
@@ -1322,7 +1311,7 @@ abstract class RDD[T: ClassTag](
def debugSelf (rdd: RDD[_]): Seq[String] = {
import Utils.bytesToString
- val persistence = storageLevel.description
+ val persistence = if (storageLevel != StorageLevel.NONE) storageLevel.description else ""
val storageInfo = rdd.context.getRDDStorageInfo.filter(_.id == rdd.id).map(info =>
" CachedPartitions: %d; MemorySize: %s; TachyonSize: %s; DiskSize: %s".format(
info.numCachedPartitions, bytesToString(info.memSize),
@@ -1396,3 +1385,31 @@ abstract class RDD[T: ClassTag](
new JavaRDD(this)(elementClassTag)
}
}
+
+object RDD {
+
+ // The following implicit functions were in SparkContext before 1.2 and users had to
+ // `import SparkContext._` to enable them. Now we move them here to make the compiler find
+ // them automatically. However, we still keep the old functions in SparkContext for backward
+ // compatibility and forward to the following functions directly.
+
+ implicit def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)])
+ (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = {
+ new PairRDDFunctions(rdd)
+ }
+
+ implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = new AsyncRDDActions(rdd)
+
+ implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag](
+ rdd: RDD[(K, V)]) =
+ new SequenceFileRDDFunctions(rdd)
+
+ implicit def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag](
+ rdd: RDD[(K, V)]) =
+ new OrderedRDDFunctions[K, V, (K, V)](rdd)
+
+ implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = new DoubleRDDFunctions(rdd)
+
+ implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) =
+ new DoubleRDDFunctions(rdd.map(x => num.toDouble(x)))
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
index b097c30f8c231..9e8cee5331cf8 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
@@ -21,8 +21,7 @@ import java.util.Random
import scala.reflect.ClassTag
-import cern.jet.random.Poisson
-import cern.jet.random.engine.DRand
+import org.apache.commons.math3.distribution.PoissonDistribution
import org.apache.spark.{Partition, TaskContext}
@@ -53,9 +52,11 @@ private[spark] class SampledRDD[T: ClassTag](
if (withReplacement) {
// For large datasets, the expected number of occurrences of each element in a sample with
// replacement is Poisson(frac). We use that to get a count for each element.
- val poisson = new Poisson(frac, new DRand(split.seed))
+ val poisson = new PoissonDistribution(frac)
+ poisson.reseedRandomGenerator(split.seed)
+
firstParent[T].iterator(split.prev, context).flatMap { element =>
- val count = poisson.nextInt()
+ val count = poisson.sample()
if (count == 0) {
Iterator.empty // Avoid object allocation when we return 0 items, which is quite often
} else {
diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
index 0c97eb0aaa51f..aece683ff3199 100644
--- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
@@ -24,6 +24,7 @@ import scala.reflect.ClassTag
import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.util.Utils
/**
* Partition for UnionRDD.
@@ -48,7 +49,7 @@ private[spark] class UnionPartition[T: ClassTag](
override val index: Int = idx
@throws(classOf[IOException])
- private def writeObject(oos: ObjectOutputStream) {
+ private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
// Update the reference to parent split at the time of task serialization
parentPartition = rdd.partitions(parentRddPartitionIndex)
oos.defaultWriteObject()
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
index f3d30f6c9b32f..996f2cd3f34a3 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
@@ -22,6 +22,7 @@ import java.io.{IOException, ObjectOutputStream}
import scala.reflect.ClassTag
import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}
+import org.apache.spark.util.Utils
private[spark] class ZippedPartitionsPartition(
idx: Int,
@@ -34,7 +35,7 @@ private[spark] class ZippedPartitionsPartition(
def partitions = partitionValues
@throws(classOf[IOException])
- private def writeObject(oos: ObjectOutputStream) {
+ private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
// Update the reference to parent split at the time of task serialization
partitionValues = rdds.map(rdd => rdd.partitions(idx))
oos.defaultWriteObject()
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
index e2c301603b4a5..8c43a559409f2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
@@ -39,21 +39,24 @@ class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long)
private[spark]
class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, Long)](prev) {
- override def getPartitions: Array[Partition] = {
+ /** The start index of each partition. */
+ @transient private val startIndices: Array[Long] = {
val n = prev.partitions.size
- val startIndices: Array[Long] =
- if (n == 0) {
- Array[Long]()
- } else if (n == 1) {
- Array(0L)
- } else {
- prev.context.runJob(
- prev,
- Utils.getIteratorSize _,
- 0 until n - 1, // do not need to count the last partition
- false
- ).scanLeft(0L)(_ + _)
- }
+ if (n == 0) {
+ Array[Long]()
+ } else if (n == 1) {
+ Array(0L)
+ } else {
+ prev.context.runJob(
+ prev,
+ Utils.getIteratorSize _,
+ 0 until n - 1, // do not need to count the last partition
+ allowLocal = false
+ ).scanLeft(0L)(_ + _)
+ }
+ }
+
+ override def getPartitions: Array[Partition] = {
firstParent[T].partitions.map(x => new ZippedWithIndexRDDPartition(x, startIndices(x.index)))
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 8135cdbb4c31f..cb8ccfbdbdcbb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -124,6 +124,9 @@ class DAGScheduler(
/** If enabled, we may run certain actions like take() and first() locally. */
private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false)
+ /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */
+ private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false)
+
private def initializeEventProcessActor() {
// blocking the thread until supervisor is started, which ensures eventProcessActor is
// not null before any job is submitted
@@ -446,7 +449,6 @@ class DAGScheduler(
}
// data structures based on StageId
stageIdToStage -= stageId
-
logDebug("After removal of stage %d, remaining stages = %d"
.format(stageId, stageIdToStage.size))
}
@@ -630,18 +632,17 @@ class DAGScheduler(
protected def runLocallyWithinThread(job: ActiveJob) {
var jobResult: JobResult = JobSucceeded
try {
- SparkEnv.set(env)
val rdd = job.finalStage.rdd
val split = rdd.partitions(job.partitions(0))
val taskContext =
- new TaskContext(job.finalStage.id, job.partitions(0), 0, true)
- TaskContext.setTaskContext(taskContext)
+ new TaskContextImpl(job.finalStage.id, job.partitions(0), 0, true)
+ TaskContextHelper.setTaskContext(taskContext)
try {
val result = job.func(taskContext, rdd.iterator(split, taskContext))
job.listener.taskSucceeded(0, result)
} finally {
taskContext.markTaskCompleted()
- TaskContext.unset()
+ TaskContextHelper.unset()
}
} catch {
case e: Exception =>
@@ -749,14 +750,15 @@ class DAGScheduler(
localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1
if (shouldRunLocally) {
// Compute very short actions like first() or take() with no parent stages locally.
- listenerBus.post(SparkListenerJobStart(job.jobId, Array[Int](), properties))
+ listenerBus.post(SparkListenerJobStart(job.jobId, Seq.empty, properties))
runLocally(job)
} else {
jobIdToActiveJob(jobId) = job
activeJobs += job
finalStage.resultOfJob = Some(job)
- listenerBus.post(SparkListenerJobStart(job.jobId, jobIdToStageIds(jobId).toArray,
- properties))
+ val stageIds = jobIdToStageIds(jobId).toArray
+ val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
+ listenerBus.post(SparkListenerJobStart(job.jobId, stageInfos, properties))
submitStage(finalStage)
}
}
@@ -899,6 +901,34 @@ class DAGScheduler(
}
}
+ /** Merge updates from a task to our local accumulator values */
+ private def updateAccumulators(event: CompletionEvent): Unit = {
+ val task = event.task
+ val stage = stageIdToStage(task.stageId)
+ if (event.accumUpdates != null) {
+ try {
+ Accumulators.add(event.accumUpdates)
+ event.accumUpdates.foreach { case (id, partialValue) =>
+ val acc = Accumulators.originals(id).asInstanceOf[Accumulable[Any, Any]]
+ // To avoid UI cruft, ignore cases where value wasn't updated
+ if (acc.name.isDefined && partialValue != acc.zero) {
+ val name = acc.name.get
+ val stringPartialValue = Accumulators.stringifyPartialValue(partialValue)
+ val stringValue = Accumulators.stringifyValue(acc.value)
+ stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, stringValue)
+ event.taskInfo.accumulables +=
+ AccumulableInfo(id, name, Some(stringPartialValue), stringValue)
+ }
+ }
+ } catch {
+ // If we see an exception during accumulator update, just log the
+ // error and move on.
+ case e: Exception =>
+ logError(s"Failed to update accumulators for $task", e)
+ }
+ }
+ }
+
/**
* Responds to a task finishing. This is called inside the event loop so it assumes that it can
* modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
@@ -939,27 +969,6 @@ class DAGScheduler(
}
event.reason match {
case Success =>
- if (event.accumUpdates != null) {
- try {
- Accumulators.add(event.accumUpdates)
- event.accumUpdates.foreach { case (id, partialValue) =>
- val acc = Accumulators.originals(id).asInstanceOf[Accumulable[Any, Any]]
- // To avoid UI cruft, ignore cases where value wasn't updated
- if (acc.name.isDefined && partialValue != acc.zero) {
- val name = acc.name.get
- val stringPartialValue = Accumulators.stringifyPartialValue(partialValue)
- val stringValue = Accumulators.stringifyValue(acc.value)
- stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, stringValue)
- event.taskInfo.accumulables +=
- AccumulableInfo(id, name, Some(stringPartialValue), stringValue)
- }
- }
- } catch {
- // If we see an exception during accumulator update, just log the error and move on.
- case e: Exception =>
- logError(s"Failed to update accumulators for $task", e)
- }
- }
listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType,
event.reason, event.taskInfo, event.taskMetrics))
stage.pendingTasks -= task
@@ -968,6 +977,7 @@ class DAGScheduler(
stage.resultOfJob match {
case Some(job) =>
if (!job.finished(rt.outputId)) {
+ updateAccumulators(event)
job.finished(rt.outputId) = true
job.numFinished += 1
// If the whole job has finished, remove it
@@ -992,6 +1002,7 @@ class DAGScheduler(
}
case smt: ShuffleMapTask =>
+ updateAccumulators(event)
val status = event.result.asInstanceOf[MapStatus]
val execId = status.location.executorId
logDebug("ShuffleMapTask finished on " + execId)
@@ -1051,7 +1062,7 @@ class DAGScheduler(
logInfo("Resubmitted " + task + ", so marking it as still running")
stage.pendingTasks += task
- case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+ case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) =>
val failedStage = stageIdToStage(task.stageId)
val mapStage = shuffleToMapStage(shuffleId)
@@ -1061,11 +1072,13 @@ class DAGScheduler(
if (runningStages.contains(failedStage)) {
logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
s"due to a fetch failure from $mapStage (${mapStage.name})")
- markStageAsFinished(failedStage, Some("Fetch failure"))
+ markStageAsFinished(failedStage, Some(failureMessage))
runningStages -= failedStage
}
- if (failedStages.isEmpty && eventProcessActor != null) {
+ if (disallowStageRetryForTest) {
+ abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
+ } else if (failedStages.isEmpty && eventProcessActor != null) {
// Don't schedule an event to resubmit failed stages if failed isn't empty, because
// in that case the event will already have been scheduled. eventProcessActor may be
// null during unit tests.
@@ -1078,7 +1091,6 @@ class DAGScheduler(
}
failedStages += failedStage
failedStages += mapStage
-
// Mark the map whose fetch failed as broken in the map stage
if (mapId != -1) {
mapStage.removeOutputLoc(mapId, bmAddress)
@@ -1087,10 +1099,10 @@ class DAGScheduler(
// TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
- handleExecutorLost(bmAddress.executorId, Some(task.epoch))
+ handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
}
- case ExceptionFailure(className, description, stackTrace, metrics) =>
+ case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) =>
// Do nothing here, left up to the TaskScheduler to decide how to handle user failures
case TaskResultLost =>
@@ -1107,25 +1119,35 @@ class DAGScheduler(
* Responds to an executor being lost. This is called inside the event loop, so it assumes it can
* modify the scheduler's internal state. Use executorLost() to post a loss event from outside.
*
+ * We will also assume that we've lost all shuffle blocks associated with the executor if the
+ * executor serves its own blocks (i.e., we're not using external shuffle) OR a FetchFailed
+ * occurred, in which case we presume all shuffle data related to this executor to be lost.
+ *
* Optionally the epoch during which the failure was caught can be passed to avoid allowing
* stray fetch failures from possibly retriggering the detection of a node as lost.
*/
- private[scheduler] def handleExecutorLost(execId: String, maybeEpoch: Option[Long] = None) {
+ private[scheduler] def handleExecutorLost(
+ execId: String,
+ fetchFailed: Boolean,
+ maybeEpoch: Option[Long] = None) {
val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch)
if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) {
failedEpoch(execId) = currentEpoch
logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch))
blockManagerMaster.removeExecutor(execId)
- // TODO: This will be really slow if we keep accumulating shuffle map stages
- for ((shuffleId, stage) <- shuffleToMapStage) {
- stage.removeOutputsOnExecutor(execId)
- val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray
- mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true)
- }
- if (shuffleToMapStage.isEmpty) {
- mapOutputTracker.incrementEpoch()
+
+ if (!env.blockManager.externalShuffleServiceEnabled || fetchFailed) {
+ // TODO: This will be really slow if we keep accumulating shuffle map stages
+ for ((shuffleId, stage) <- shuffleToMapStage) {
+ stage.removeOutputsOnExecutor(execId)
+ val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray
+ mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true)
+ }
+ if (shuffleToMapStage.isEmpty) {
+ mapOutputTracker.incrementEpoch()
+ }
+ clearCacheLocs()
}
- clearCacheLocs()
} else {
logDebug("Additional executor lost message for " + execId +
"(epoch " + currentEpoch + ")")
@@ -1383,7 +1405,7 @@ private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGSchedule
dagScheduler.handleExecutorAdded(execId, host)
case ExecutorLost(execId) =>
- dagScheduler.handleExecutorLost(execId)
+ dagScheduler.handleExecutorLost(execId, fetchFailed = false)
case BeginEvent(task, taskInfo) =>
dagScheduler.handleBeginEvent(task, taskInfo)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
index 100c9ba9b7809..597dbc884913c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
@@ -142,7 +142,7 @@ private[spark] object EventLoggingListener extends Logging {
val SPARK_VERSION_PREFIX = "SPARK_VERSION_"
val COMPRESSION_CODEC_PREFIX = "COMPRESSION_CODEC_"
val APPLICATION_COMPLETE = "APPLICATION_COMPLETE"
- val LOG_FILE_PERMISSIONS = FsPermission.createImmutable(Integer.parseInt("770", 8).toShort)
+ val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort)
// A cache for compression codecs to avoid creating the same codec many times
private val codecMap = new mutable.HashMap[String, CompressionCodec]
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index 54904bffdf10b..3bb54855bae44 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -158,6 +158,11 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener
" INPUT_BYTES=" + metrics.bytesRead
case None => ""
}
+ val outputMetrics = taskMetrics.outputMetrics match {
+ case Some(metrics) =>
+ " OUTPUT_BYTES=" + metrics.bytesWritten
+ case None => ""
+ }
val shuffleReadMetrics = taskMetrics.shuffleReadMetrics match {
case Some(metrics) =>
" BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
@@ -173,7 +178,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener
" SHUFFLE_WRITE_TIME=" + metrics.shuffleWriteTime
case None => ""
}
- stageLogInfo(stageId, status + info + executorRunTime + gcTime + inputMetrics +
+ stageLogInfo(stageId, status + info + executorRunTime + gcTime + inputMetrics + outputMetrics +
shuffleReadMetrics + writeMetrics)
}
@@ -215,7 +220,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener
taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
" STAGE_ID=" + taskEnd.stageId
stageLogInfo(taskEnd.stageId, taskStatus)
- case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+ case FetchFailed(bmAddress, shuffleId, mapId, reduceId, message) =>
taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
taskEnd.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
mapId + " REDUCE_ID=" + reduceId
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index e25096ea92d70..01d5943d777f3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -19,7 +19,10 @@ package org.apache.spark.scheduler
import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import org.roaringbitmap.RoaringBitmap
+
import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.Utils
/**
* Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the
@@ -29,7 +32,12 @@ private[spark] sealed trait MapStatus {
/** Location where this task was run. */
def location: BlockManagerId
- /** Estimated size for the reduce block, in bytes. */
+ /**
+ * Estimated size for the reduce block, in bytes.
+ *
+ * If a block is non-empty, then this method MUST return a non-zero size. This invariant is
+ * necessary for correctness, since block fetchers are allowed to skip zero-size blocks.
+ */
def getSizeForBlock(reduceId: Int): Long
}
@@ -38,7 +46,7 @@ private[spark] object MapStatus {
def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = {
if (uncompressedSizes.length > 2000) {
- new HighlyCompressedMapStatus(loc, uncompressedSizes)
+ HighlyCompressedMapStatus(loc, uncompressedSizes)
} else {
new CompressedMapStatus(loc, uncompressedSizes)
}
@@ -98,13 +106,13 @@ private[spark] class CompressedMapStatus(
MapStatus.decompressSize(compressedSizes(reduceId))
}
- override def writeExternal(out: ObjectOutput): Unit = {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
loc.writeExternal(out)
out.writeInt(compressedSizes.length)
out.write(compressedSizes)
}
- override def readExternal(in: ObjectInput): Unit = {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
loc = BlockManagerId(in)
val len = in.readInt()
compressedSizes = new Array[Byte](len)
@@ -112,35 +120,80 @@ private[spark] class CompressedMapStatus(
}
}
-
/**
- * A [[MapStatus]] implementation that only stores the average size of the blocks.
+ * A [[MapStatus]] implementation that only stores the average size of non-empty blocks,
+ * plus a bitmap for tracking which blocks are non-empty. During serialization, this bitmap
+ * is compressed.
*
- * @param loc location where the task is being executed.
- * @param avgSize average size of all the blocks
+ * @param loc location where the task is being executed
+ * @param numNonEmptyBlocks the number of non-empty blocks
+ * @param emptyBlocks a bitmap tracking which blocks are empty
+ * @param avgSize average size of the non-empty blocks
*/
-private[spark] class HighlyCompressedMapStatus(
+private[spark] class HighlyCompressedMapStatus private (
private[this] var loc: BlockManagerId,
+ private[this] var numNonEmptyBlocks: Int,
+ private[this] var emptyBlocks: RoaringBitmap,
private[this] var avgSize: Long)
extends MapStatus with Externalizable {
- def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) {
- this(loc, uncompressedSizes.sum / uncompressedSizes.length)
- }
+ // loc could be null when the default constructor is called during deserialization
+ require(loc == null || avgSize > 0 || numNonEmptyBlocks == 0,
+ "Average size can only be zero for map stages that produced no output")
- protected def this() = this(null, 0L) // For deserialization only
+ protected def this() = this(null, -1, null, -1) // For deserialization only
override def location: BlockManagerId = loc
- override def getSizeForBlock(reduceId: Int): Long = avgSize
+ override def getSizeForBlock(reduceId: Int): Long = {
+ if (emptyBlocks.contains(reduceId)) {
+ 0
+ } else {
+ avgSize
+ }
+ }
- override def writeExternal(out: ObjectOutput): Unit = {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
loc.writeExternal(out)
+ emptyBlocks.writeExternal(out)
out.writeLong(avgSize)
}
- override def readExternal(in: ObjectInput): Unit = {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
loc = BlockManagerId(in)
+ emptyBlocks = new RoaringBitmap()
+ emptyBlocks.readExternal(in)
avgSize = in.readLong()
}
}
+
+private[spark] object HighlyCompressedMapStatus {
+ def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = {
+ // We must keep track of which blocks are empty so that we don't report a zero-sized
+ // block as being non-empty (or vice-versa) when using the average block size.
+ var i = 0
+ var numNonEmptyBlocks: Int = 0
+ var totalSize: Long = 0
+ // From a compression standpoint, it shouldn't matter whether we track empty or non-empty
+ // blocks. From a performance standpoint, we benefit from tracking empty blocks because
+ // we expect that there will be far fewer of them, so we will perform fewer bitmap insertions.
+ val emptyBlocks = new RoaringBitmap()
+ val totalNumBlocks = uncompressedSizes.length
+ while (i < totalNumBlocks) {
+ var size = uncompressedSizes(i)
+ if (size > 0) {
+ numNonEmptyBlocks += 1
+ totalSize += size
+ } else {
+ emptyBlocks.add(i)
+ }
+ i += 1
+ }
+ val avgSize = if (numNonEmptyBlocks > 0) {
+ totalSize / numNonEmptyBlocks
+ } else {
+ 0
+ }
+ new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 86afe3bd5265f..b62b0c1312693 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -56,8 +56,15 @@ case class SparkListenerTaskEnd(
extends SparkListenerEvent
@DeveloperApi
-case class SparkListenerJobStart(jobId: Int, stageIds: Seq[Int], properties: Properties = null)
- extends SparkListenerEvent
+case class SparkListenerJobStart(
+ jobId: Int,
+ stageInfos: Seq[StageInfo],
+ properties: Properties = null)
+ extends SparkListenerEvent {
+ // Note: this is here for backwards-compatibility with older versions of this event which
+ // only stored stageIds and not StageInfos:
+ val stageIds: Seq[Int] = stageInfos.map(_.stageId)
+}
@DeveloperApi
case class SparkListenerJobEnd(jobId: Int, jobResult: JobResult) extends SparkListenerEvent
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 071568cdfb429..cc13f57a49b89 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -102,6 +102,11 @@ private[spark] class Stage(
}
}
+ /**
+ * Removes all shuffle outputs associated with this executor. Note that this will also remove
+ * outputs which are served by an external shuffle server (if one exists), as they are still
+ * registered with this execId.
+ */
def removeOutputsOnExecutor(execId: String) {
var becameUnavailable = false
for (partition <- 0 until numPartitions) {
@@ -131,4 +136,9 @@ private[spark] class Stage(
override def toString = "Stage " + id
override def hashCode(): Int = id
+
+ override def equals(other: Any): Boolean = other match {
+ case stage: Stage => stage != null && stage.id == id
+ case _ => false
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index c6e47c84a0cb2..2552d03d18d06 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -22,7 +22,7 @@ import java.nio.ByteBuffer
import scala.collection.mutable.HashMap
-import org.apache.spark.TaskContext
+import org.apache.spark.{TaskContextHelper, TaskContextImpl, TaskContext}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util.ByteBufferInputStream
@@ -45,8 +45,8 @@ import org.apache.spark.util.Utils
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
final def run(attemptId: Long): T = {
- context = new TaskContext(stageId, partitionId, attemptId, false)
- TaskContext.setTaskContext(context)
+ context = new TaskContextImpl(stageId, partitionId, attemptId, false)
+ TaskContextHelper.setTaskContext(context)
context.taskMetrics.hostname = Utils.localHostName()
taskThread = Thread.currentThread()
if (_killed) {
@@ -56,7 +56,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
runTask(context)
} finally {
context.markTaskCompleted()
- TaskContext.unset()
+ TaskContextHelper.unset()
}
}
@@ -70,7 +70,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
var metrics: Option[TaskMetrics] = None
// Task context, to be initialized in run().
- @transient protected var context: TaskContext = _
+ @transient protected var context: TaskContextImpl = _
// The actual Thread on which the task is running, if any. Initialized in run().
@volatile @transient private var taskThread: Thread = _
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index d49d8fb887007..1f114a0207f7b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -31,8 +31,8 @@ import org.apache.spark.util.Utils
private[spark] sealed trait TaskResult[T]
/** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */
-private[spark]
-case class IndirectTaskResult[T](blockId: BlockId) extends TaskResult[T] with Serializable
+private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int)
+ extends TaskResult[T] with Serializable
/** A TaskResult that contains the task's return value and accumulator updates. */
private[spark]
@@ -42,7 +42,7 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long
def this() = this(null.asInstanceOf[ByteBuffer], null, null)
- override def writeExternal(out: ObjectOutput) {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
out.writeInt(valueBytes.remaining);
Utils.writeByteBuffer(valueBytes, out)
@@ -55,7 +55,7 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long
out.writeObject(metrics)
}
- override def readExternal(in: ObjectInput) {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
val blen = in.readInt()
val byteVal = new Array[Byte](blen)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 3f345ceeaaf7a..819b51e12ad8c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -47,9 +47,18 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
getTaskResultExecutor.execute(new Runnable {
override def run(): Unit = Utils.logUncaughtExceptions {
try {
- val result = serializer.get().deserialize[TaskResult[_]](serializedData) match {
- case directResult: DirectTaskResult[_] => directResult
- case IndirectTaskResult(blockId) =>
+ val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match {
+ case directResult: DirectTaskResult[_] =>
+ if (!taskSetManager.canFetchMoreResults(serializedData.limit())) {
+ return
+ }
+ (directResult, serializedData.limit())
+ case IndirectTaskResult(blockId, size) =>
+ if (!taskSetManager.canFetchMoreResults(size)) {
+ // dropped by executor if size is larger than maxResultSize
+ sparkEnv.blockManager.master.removeBlock(blockId)
+ return
+ }
logDebug("Fetching indirect task result for TID %s".format(tid))
scheduler.handleTaskGettingResult(taskSetManager, tid)
val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
@@ -64,9 +73,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]](
serializedTaskResult.get)
sparkEnv.blockManager.master.removeBlock(blockId)
- deserializedResult
+ (deserializedResult, size)
}
- result.metrics.resultSize = serializedData.limit()
+
+ result.metrics.resultSize = size
scheduler.handleSuccessfulTask(taskSetManager, tid, result)
} catch {
case cnf: ClassNotFoundException =>
@@ -93,7 +103,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
}
} catch {
case cnd: ClassNotFoundException =>
- // Log an error but keep going here -- the task failed, so not catastropic if we can't
+ // Log an error but keep going here -- the task failed, so not catastrophic if we can't
// deserialize the reason.
val loader = Utils.getContextOrSparkClassLoader
logError(
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index a129a434c9a1a..f095915352b17 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -23,7 +23,7 @@ import org.apache.spark.storage.BlockManagerId
/**
* Low-level task scheduler interface, currently implemented exclusively by TaskSchedulerImpl.
- * This interface allows plugging in different task schedulers. Each TaskScheduler schedulers tasks
+ * This interface allows plugging in different task schedulers. Each TaskScheduler schedules tasks
* for a single SparkContext. These schedulers get sets of tasks submitted to them from the
* DAGScheduler for each stage, and are responsible for sending the tasks to the cluster, running
* them, retrying if there are failures, and mitigating stragglers. They return events to the
@@ -41,7 +41,7 @@ private[spark] trait TaskScheduler {
// Invoked after system has successfully initialized (typically in spark context).
// Yarn uses this to bootstrap allocation of resources based on preferred locations,
- // wait for slave registerations, etc.
+ // wait for slave registrations, etc.
def postStartHook() { }
// Disconnect from the cluster.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 4dc550413c13c..cd3c015321e85 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -34,7 +34,6 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.util.Utils
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.storage.BlockManagerId
-import akka.actor.Props
/**
* Schedules tasks for multiple types of clusters by acting through a SchedulerBackend.
@@ -216,13 +215,12 @@ private[spark] class TaskSchedulerImpl(
* that tasks are balanced across the cluster.
*/
def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized {
- SparkEnv.set(sc.env)
-
// Mark each slave as alive and remember its hostname
// Also track if new executor is added
var newExecAvail = false
for (o <- offers) {
executorIdToHost(o.executorId) = o.host
+ activeExecutorIds += o.executorId
if (!executorsByHost.contains(o.host)) {
executorsByHost(o.host) = new HashSet[String]()
executorAdded(o.executorId, o.host)
@@ -263,7 +261,6 @@ private[spark] class TaskSchedulerImpl(
val tid = task.taskId
taskIdToTaskSetId(tid) = taskSet.taskSet.id
taskIdToExecutorId(tid) = execId
- activeExecutorIds += execId
executorsByHost(host) += execId
availableCpus(i) -= CPUS_PER_TASK
assert(availableCpus(i) >= 0)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index a6c23fc85a1b0..cabdc655f89bf 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -23,13 +23,12 @@ import java.util.Arrays
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
-import scala.math.max
-import scala.math.min
+import scala.math.{min, max}
import org.apache.spark._
-import org.apache.spark.TaskState.TaskState
import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.util.{Clock, SystemClock}
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.util.{Clock, SystemClock, Utils}
/**
* Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of
@@ -68,6 +67,9 @@ private[spark] class TaskSetManager(
val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75)
val SPECULATION_MULTIPLIER = conf.getDouble("spark.speculation.multiplier", 1.5)
+ // Limit of bytes for total size of results (default is 1GB)
+ val maxResultSize = Utils.getMaxResultSize(conf)
+
// Serializer for closures and tasks.
val env = SparkEnv.get
val ser = env.closureSerializer.newInstance()
@@ -89,6 +91,8 @@ private[spark] class TaskSetManager(
var stageId = taskSet.stageId
var name = "TaskSet_" + taskSet.stageId.toString
var parent: Pool = null
+ var totalResultSize = 0L
+ var calculatedTasks = 0
val runningTasksSet = new HashSet[Long]
override def runningTasks = runningTasksSet.size
@@ -515,12 +519,33 @@ private[spark] class TaskSetManager(
index
}
+ /**
+ * Marks the task as getting result and notifies the DAG Scheduler
+ */
def handleTaskGettingResult(tid: Long) = {
val info = taskInfos(tid)
info.markGettingResult()
sched.dagScheduler.taskGettingResult(info)
}
+ /**
+ * Check whether has enough quota to fetch the result with `size` bytes
+ */
+ def canFetchMoreResults(size: Long): Boolean = synchronized {
+ totalResultSize += size
+ calculatedTasks += 1
+ if (maxResultSize > 0 && totalResultSize > maxResultSize) {
+ val msg = s"Total size of serialized results of ${calculatedTasks} tasks " +
+ s"(${Utils.bytesToString(totalResultSize)}) is bigger than spark.driver.maxResultSize " +
+ s"(${Utils.bytesToString(maxResultSize)})"
+ logError(msg)
+ abort(msg)
+ false
+ } else {
+ true
+ }
+ }
+
/**
* Marks the task as successful and notifies the DAGScheduler that a task has ended.
*/
@@ -687,10 +712,11 @@ private[spark] class TaskSetManager(
addPendingTask(index, readding=true)
}
- // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage.
+ // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage,
+ // and we are not using an external shuffle server which could serve the shuffle outputs.
// The reason is the next stage wouldn't be able to fetch the data from this dead executor
// so we would need to rerun these tasks on other executors.
- if (tasks(0).isInstanceOf[ShuffleMapTask]) {
+ if (tasks(0).isInstanceOf[ShuffleMapTask] && !env.blockManager.externalShuffleServiceEnabled) {
for ((tid, info) <- taskInfos if info.executorId == execId) {
val index = taskInfos(tid).index
if (successful(index)) {
@@ -706,7 +732,7 @@ private[spark] class TaskSetManager(
}
// Also re-enqueue any tasks that were running on the node
for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
- handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure)
+ handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(execId))
}
// recalculate valid locality levels and waits when executor is lost
recomputeLocality()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index fb8160abc59db..1da6fe976da5b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -66,7 +66,19 @@ private[spark] object CoarseGrainedClusterMessages {
case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage
- case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase :String)
+ // Exchanged between the driver and the AM in Yarn client mode
+ case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase: String)
extends CoarseGrainedClusterMessage
+ // Messages exchanged between the driver and the cluster manager for executor allocation
+ // In Yarn mode, these are exchanged between the driver and the AM
+
+ case object RegisterClusterManager extends CoarseGrainedClusterMessage
+
+ // Request executors by specifying the new total number of executors desired
+ // This includes executors already pending or running
+ case class RequestExecutors(requestedTotal: Int) extends CoarseGrainedClusterMessage
+
+ case class KillExecutors(executorIds: Seq[String]) extends CoarseGrainedClusterMessage
+
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 59aed6b72fe42..88b196ac64368 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -31,7 +31,6 @@ import org.apache.spark.{SparkEnv, Logging, SparkException, TaskState}
import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer}
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils}
-import org.apache.spark.ui.JettyUtils
/**
* A scheduler backend that waits for coarse grained executors to connect to it through Akka.
@@ -42,11 +41,12 @@ import org.apache.spark.ui.JettyUtils
* (spark.deploy.*).
*/
private[spark]
-class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: ActorSystem)
+class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSystem: ActorSystem)
extends SchedulerBackend with Logging
{
// Use an atomic variable to track total number of cores in the cluster for simplicity and speed
var totalCoreCount = new AtomicInteger(0)
+ // Total number of executors that are currently registered
var totalRegisteredExecutors = new AtomicInteger(0)
val conf = scheduler.sc.conf
private val timeout = AkkaUtils.askTimeout(conf)
@@ -61,10 +61,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
conf.getInt("spark.scheduler.maxRegisteredResourcesWaitingTime", 30000)
val createTime = System.currentTimeMillis()
+ private val executorDataMap = new HashMap[String, ExecutorData]
+
+ // Number of executors requested from the cluster manager that have not registered yet
+ private var numPendingExecutors = 0
+
+ // Executors we have requested the cluster manager to kill that have not died yet
+ private val executorsPendingToRemove = new HashSet[String]
+
class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive {
override protected def log = CoarseGrainedSchedulerBackend.this.log
private val addressToExecutorId = new HashMap[Address, String]
- private val executorDataMap = new HashMap[String, ExecutorData]
override def preStart() {
// Listen for remote client disconnection events, since they don't go through Akka's watch()
@@ -84,12 +91,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
} else {
logInfo("Registered executor: " + sender + " with ID " + executorId)
sender ! RegisteredExecutor
- executorDataMap.put(executorId, new ExecutorData(sender, sender.path.address,
- Utils.parseHostPort(hostPort)._1, cores, cores))
addressToExecutorId(sender.path.address) = executorId
totalCoreCount.addAndGet(cores)
totalRegisteredExecutors.addAndGet(1)
+ val (host, _) = Utils.parseHostPort(hostPort)
+ val data = new ExecutorData(sender, sender.path.address, host, cores, cores)
+ // This must be synchronized because variables mutated
+ // in this block are read when requesting executors
+ CoarseGrainedSchedulerBackend.this.synchronized {
+ executorDataMap.put(executorId, data)
+ if (numPendingExecutors > 0) {
+ numPendingExecutors -= 1
+ logDebug(s"Decremented number of pending executors ($numPendingExecutors left)")
+ }
+ }
makeOffers()
}
@@ -111,7 +127,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
makeOffers()
case KillTask(taskId, executorId, interruptThread) =>
- executorDataMap(executorId).executorActor ! KillTask(taskId, executorId, interruptThread)
+ executorDataMap.get(executorId) match {
+ case Some(executorInfo) =>
+ executorInfo.executorActor ! KillTask(taskId, executorId, interruptThread)
+ case None =>
+ // Ignoring the task kill since the executor is not registered.
+ logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.")
+ }
case StopDriver =>
sender ! true
@@ -128,10 +150,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
removeExecutor(executorId, reason)
sender ! true
- case AddWebUIFilter(filterName, filterParams, proxyBase) =>
- addWebUIFilter(filterName, filterParams, proxyBase)
- sender ! true
-
case DisassociatedEvent(_, address, _) =>
addressToExecutorId.get(address).foreach(removeExecutor(_,
"remote Akka client disassociated"))
@@ -183,13 +201,19 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
}
// Remove a disconnected slave from the cluster
- def removeExecutor(executorId: String, reason: String) {
+ def removeExecutor(executorId: String, reason: String): Unit = {
executorDataMap.get(executorId) match {
case Some(executorInfo) =>
- executorDataMap -= executorId
+ // This must be synchronized because variables mutated
+ // in this block are read when requesting executors
+ CoarseGrainedSchedulerBackend.this.synchronized {
+ executorDataMap -= executorId
+ executorsPendingToRemove -= executorId
+ }
totalCoreCount.addAndGet(-executorInfo.totalCores)
+ totalRegisteredExecutors.addAndGet(-1)
scheduler.executorLost(executorId, SlaveLost(reason))
- case None => logError(s"Asked to remove non existant executor $executorId")
+ case None => logError(s"Asked to remove non-existent executor $executorId")
}
}
}
@@ -274,21 +298,62 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
false
}
- // Add filters to the SparkUI
- def addWebUIFilter(filterName: String, filterParams: Map[String, String], proxyBase: String) {
- if (proxyBase != null && proxyBase.nonEmpty) {
- System.setProperty("spark.ui.proxyBase", proxyBase)
- }
+ /**
+ * Return the number of executors currently registered with this backend.
+ */
+ def numExistingExecutors: Int = executorDataMap.size
+
+ /**
+ * Request an additional number of executors from the cluster manager.
+ * Return whether the request is acknowledged.
+ */
+ final def requestExecutors(numAdditionalExecutors: Int): Boolean = synchronized {
+ logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager")
+ logDebug(s"Number of pending executors is now $numPendingExecutors")
+ numPendingExecutors += numAdditionalExecutors
+ // Account for executors pending to be added or removed
+ val newTotal = numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size
+ doRequestTotalExecutors(newTotal)
+ }
- val hasFilter = (filterName != null && filterName.nonEmpty &&
- filterParams != null && filterParams.nonEmpty)
- if (hasFilter) {
- logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase")
- conf.set("spark.ui.filters", filterName)
- filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) }
- scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) }
+ /**
+ * Request executors from the cluster manager by specifying the total number desired,
+ * including existing pending and running executors.
+ *
+ * The semantics here guarantee that we do not over-allocate executors for this application,
+ * since a later request overrides the value of any prior request. The alternative interface
+ * of requesting a delta of executors risks double counting new executors when there are
+ * insufficient resources to satisfy the first request. We make the assumption here that the
+ * cluster manager will eventually fulfill all requests when resources free up.
+ *
+ * Return whether the request is acknowledged.
+ */
+ protected def doRequestTotalExecutors(requestedTotal: Int): Boolean = false
+
+ /**
+ * Request that the cluster manager kill the specified executors.
+ * Return whether the kill request is acknowledged.
+ */
+ final def killExecutors(executorIds: Seq[String]): Boolean = {
+ logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}")
+ val filteredExecutorIds = new ArrayBuffer[String]
+ executorIds.foreach { id =>
+ if (executorDataMap.contains(id)) {
+ filteredExecutorIds += id
+ } else {
+ logWarning(s"Executor to kill $id does not exist!")
+ }
}
+ executorsPendingToRemove ++= filteredExecutorIds
+ doKillExecutors(filteredExecutorIds)
}
+
+ /**
+ * Kill the given list of executors through the cluster manager.
+ * Return whether the kill request is acknowledged.
+ */
+ protected def doKillExecutors(executorIds: Seq[String]): Boolean = false
+
}
private[spark] object CoarseGrainedSchedulerBackend {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index ed209d195ec9d..8c7de75600b5f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -51,7 +51,8 @@ private[spark] class SparkDeploySchedulerBackend(
conf.get("spark.driver.host"),
conf.get("spark.driver.port"),
CoarseGrainedSchedulerBackend.ACTOR_NAME)
- val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{WORKER_URL}}")
+ val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{APP_ID}}",
+ "{{WORKER_URL}}")
val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions")
.map(Utils.splitCommandString).getOrElse(Seq.empty)
val classPathEntries = sc.conf.getOption("spark.executor.extraClassPath").toSeq.flatMap { cp =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
new file mode 100644
index 0000000000000..50721b9d6cd6c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import akka.actor.{Actor, ActorRef, Props}
+import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
+
+import org.apache.spark.SparkContext
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
+import org.apache.spark.scheduler.TaskSchedulerImpl
+import org.apache.spark.ui.JettyUtils
+import org.apache.spark.util.AkkaUtils
+
+/**
+ * Abstract Yarn scheduler backend that contains common logic
+ * between the client and cluster Yarn scheduler backends.
+ */
+private[spark] abstract class YarnSchedulerBackend(
+ scheduler: TaskSchedulerImpl,
+ sc: SparkContext)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) {
+
+ if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) {
+ minRegisteredRatio = 0.8
+ }
+
+ protected var totalExpectedExecutors = 0
+
+ private val yarnSchedulerActor: ActorRef =
+ actorSystem.actorOf(
+ Props(new YarnSchedulerActor),
+ name = YarnSchedulerBackend.ACTOR_NAME)
+
+ private implicit val askTimeout = AkkaUtils.askTimeout(sc.conf)
+
+ /**
+ * Request executors from the ApplicationMaster by specifying the total number desired.
+ * This includes executors already pending or running.
+ */
+ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = {
+ AkkaUtils.askWithReply[Boolean](
+ RequestExecutors(requestedTotal), yarnSchedulerActor, askTimeout)
+ }
+
+ /**
+ * Request that the ApplicationMaster kill the specified executors.
+ */
+ override def doKillExecutors(executorIds: Seq[String]): Boolean = {
+ AkkaUtils.askWithReply[Boolean](
+ KillExecutors(executorIds), yarnSchedulerActor, askTimeout)
+ }
+
+ override def sufficientResourcesRegistered(): Boolean = {
+ totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio
+ }
+
+ /**
+ * Add filters to the SparkUI.
+ */
+ private def addWebUIFilter(
+ filterName: String,
+ filterParams: Map[String, String],
+ proxyBase: String): Unit = {
+ if (proxyBase != null && proxyBase.nonEmpty) {
+ System.setProperty("spark.ui.proxyBase", proxyBase)
+ }
+
+ val hasFilter =
+ filterName != null && filterName.nonEmpty &&
+ filterParams != null && filterParams.nonEmpty
+ if (hasFilter) {
+ logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase")
+ conf.set("spark.ui.filters", filterName)
+ filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) }
+ scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) }
+ }
+ }
+
+ /**
+ * An actor that communicates with the ApplicationMaster.
+ */
+ private class YarnSchedulerActor extends Actor {
+ private var amActor: Option[ActorRef] = None
+
+ override def preStart(): Unit = {
+ // Listen for disassociation events
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+ }
+
+ override def receive = {
+ case RegisterClusterManager =>
+ logInfo(s"ApplicationMaster registered as $sender")
+ amActor = Some(sender)
+
+ case r: RequestExecutors =>
+ amActor match {
+ case Some(actor) =>
+ sender ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout)
+ case None =>
+ logWarning("Attempted to request executors before the AM has registered!")
+ sender ! false
+ }
+
+ case k: KillExecutors =>
+ amActor match {
+ case Some(actor) =>
+ sender ! AkkaUtils.askWithReply[Boolean](k, actor, askTimeout)
+ case None =>
+ logWarning("Attempted to kill executors before the AM has registered!")
+ sender ! false
+ }
+
+ case AddWebUIFilter(filterName, filterParams, proxyBase) =>
+ addWebUIFilter(filterName, filterParams, proxyBase)
+ sender ! true
+
+ case d: DisassociatedEvent =>
+ if (amActor.isDefined && sender == amActor.get) {
+ logWarning(s"ApplicationMaster has disassociated: $d")
+ }
+ }
+ }
+}
+
+private[spark] object YarnSchedulerBackend {
+ val ACTOR_NAME = "YarnScheduler"
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index 90828578cd88f..5289661eb896b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -31,6 +31,7 @@ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTas
import org.apache.spark.{Logging, SparkContext, SparkEnv, SparkException}
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
+import org.apache.spark.util.Utils
/**
* A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds
@@ -92,7 +93,7 @@ private[spark] class CoarseMesosSchedulerBackend(
setDaemon(true)
override def run() {
val scheduler = CoarseMesosSchedulerBackend.this
- val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(sc.appName).build()
+ val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build()
driver = new MesosSchedulerDriver(scheduler, fwInfo, master)
try { {
val ret = driver.run()
@@ -120,16 +121,18 @@ private[spark] class CoarseMesosSchedulerBackend(
environment.addVariables(
Environment.Variable.newBuilder().setName("SPARK_CLASSPATH").setValue(cp).build())
}
- val extraJavaOpts = conf.getOption("spark.executor.extraJavaOptions")
+ val extraJavaOpts = conf.get("spark.executor.extraJavaOptions", "")
- val libraryPathOption = "spark.executor.extraLibraryPath"
- val extraLibraryPath = conf.getOption(libraryPathOption).map(p => s"-Djava.library.path=$p")
- val extraOpts = Seq(extraJavaOpts, extraLibraryPath).flatten.mkString(" ")
+ // Set the environment variable through a command prefix
+ // to append to the existing value of the variable
+ val prefixEnv = conf.getOption("spark.executor.extraLibraryPath").map { p =>
+ Utils.libraryPathEnvPrefix(Seq(p))
+ }.getOrElse("")
environment.addVariables(
Environment.Variable.newBuilder()
.setName("SPARK_EXECUTOR_OPTS")
- .setValue(extraOpts)
+ .setValue(extraJavaOpts)
.build())
sc.executorEnvs.foreach { case (key, value) =>
@@ -150,17 +153,18 @@ private[spark] class CoarseMesosSchedulerBackend(
if (uri == null) {
val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath
command.setValue(
- "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d".format(
- runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
+ "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s".format(
+ prefixEnv, runScript, driverUrl, offer.getSlaveId.getValue,
+ offer.getHostname, numCores, appId))
} else {
// Grab everything to the first '.'. We'll use that and '*' to
// glob the directory "correctly".
val basename = uri.split('/').last.split('.').head
command.setValue(
- ("cd %s*; " +
- "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d")
- .format(basename, driverUrl, offer.getSlaveId.getValue,
- offer.getHostname, numCores))
+ ("cd %s*; %s " +
+ "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s")
+ .format(basename, prefixEnv, driverUrl, offer.getSlaveId.getValue,
+ offer.getHostname, numCores, appId))
command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
}
command.build()
@@ -238,8 +242,7 @@ private[spark] class CoarseMesosSchedulerBackend(
for (r <- res if r.getName == name) {
return r.getScalar.getValue
}
- // If we reached here, no resource with the required name was present
- throw new IllegalArgumentException("No resource called " + name + " in " + res)
+ 0
}
/** Build a Mesos resource protobuf object */
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index b11786368e661..10e6886c16a4f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -72,7 +72,7 @@ private[spark] class MesosSchedulerBackend(
setDaemon(true)
override def run() {
val scheduler = MesosSchedulerBackend.this
- val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(sc.appName).build()
+ val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build()
driver = new MesosSchedulerDriver(scheduler, fwInfo, master)
try {
val ret = driver.run()
@@ -98,15 +98,16 @@ private[spark] class MesosSchedulerBackend(
environment.addVariables(
Environment.Variable.newBuilder().setName("SPARK_CLASSPATH").setValue(cp).build())
}
- val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions")
- val extraLibraryPath = sc.conf.getOption("spark.executor.extraLibraryPath").map { lp =>
- s"-Djava.library.path=$lp"
- }
- val extraOpts = Seq(extraJavaOpts, extraLibraryPath).flatten.mkString(" ")
+ val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").getOrElse("")
+
+ val prefixEnv = sc.conf.getOption("spark.executor.extraLibraryPath").map { p =>
+ Utils.libraryPathEnvPrefix(Seq(p))
+ }.getOrElse("")
+
environment.addVariables(
Environment.Variable.newBuilder()
.setName("SPARK_EXECUTOR_OPTS")
- .setValue(extraOpts)
+ .setValue(extraJavaOpts)
.build())
sc.executorEnvs.foreach { case (key, value) =>
environment.addVariables(Environment.Variable.newBuilder()
@@ -118,12 +119,13 @@ private[spark] class MesosSchedulerBackend(
.setEnvironment(environment)
val uri = sc.conf.get("spark.executor.uri", null)
if (uri == null) {
- command.setValue(new File(executorSparkHome, "/sbin/spark-executor").getCanonicalPath)
+ val executorPath = new File(executorSparkHome, "/sbin/spark-executor").getCanonicalPath
+ command.setValue("%s %s".format(prefixEnv, executorPath))
} else {
// Grab everything to the first '.'. We'll use that and '*' to
// glob the directory "correctly".
val basename = uri.split('/').last.split('.').head
- command.setValue("cd %s*; ./sbin/spark-executor".format(basename))
+ command.setValue("cd %s*; %s ./sbin/spark-executor".format(basename, prefixEnv))
command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
}
val cpus = Resource.newBuilder()
@@ -164,29 +166,16 @@ private[spark] class MesosSchedulerBackend(
execArgs
}
- private def setClassLoader(): ClassLoader = {
- val oldClassLoader = Thread.currentThread.getContextClassLoader
- Thread.currentThread.setContextClassLoader(classLoader)
- oldClassLoader
- }
-
- private def restoreClassLoader(oldClassLoader: ClassLoader) {
- Thread.currentThread.setContextClassLoader(oldClassLoader)
- }
-
override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) {
- val oldClassLoader = setClassLoader()
- try {
+ inClassLoader() {
appId = frameworkId.getValue
logInfo("Registered as framework ID " + appId)
registeredLock.synchronized {
isRegistered = true
registeredLock.notifyAll()
}
- } finally {
- restoreClassLoader(oldClassLoader)
}
}
@@ -198,6 +187,16 @@ private[spark] class MesosSchedulerBackend(
}
}
+ private def inClassLoader()(fun: => Unit) = {
+ val oldClassLoader = Thread.currentThread.getContextClassLoader
+ Thread.currentThread.setContextClassLoader(classLoader)
+ try {
+ fun
+ } finally {
+ Thread.currentThread.setContextClassLoader(oldClassLoader)
+ }
+ }
+
override def disconnected(d: SchedulerDriver) {}
override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {}
@@ -208,66 +207,70 @@ private[spark] class MesosSchedulerBackend(
* tasks are balanced across the cluster.
*/
override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) {
- val oldClassLoader = setClassLoader()
- try {
- synchronized {
- // Build a big list of the offerable workers, and remember their indices so that we can
- // figure out which Offer to reply to for each worker
- val offerableWorkers = new ArrayBuffer[WorkerOffer]
- val offerableIndices = new HashMap[String, Int]
-
- def sufficientOffer(o: Offer) = {
- val mem = getResource(o.getResourcesList, "mem")
- val cpus = getResource(o.getResourcesList, "cpus")
- val slaveId = o.getSlaveId.getValue
- (mem >= MemoryUtils.calculateTotalMemory(sc) &&
- // need at least 1 for executor, 1 for task
- cpus >= 2 * scheduler.CPUS_PER_TASK) ||
- (slaveIdsWithExecutors.contains(slaveId) &&
- cpus >= scheduler.CPUS_PER_TASK)
- }
+ inClassLoader() {
+ // Fail-fast on offers we know will be rejected
+ val (usableOffers, unUsableOffers) = offers.partition { o =>
+ val mem = getResource(o.getResourcesList, "mem")
+ val cpus = getResource(o.getResourcesList, "cpus")
+ val slaveId = o.getSlaveId.getValue
+ // TODO(pwendell): Should below be 1 + scheduler.CPUS_PER_TASK?
+ (mem >= MemoryUtils.calculateTotalMemory(sc) &&
+ // need at least 1 for executor, 1 for task
+ cpus >= 2 * scheduler.CPUS_PER_TASK) ||
+ (slaveIdsWithExecutors.contains(slaveId) &&
+ cpus >= scheduler.CPUS_PER_TASK)
+ }
- for ((offer, index) <- offers.zipWithIndex if sufficientOffer(offer)) {
- val slaveId = offer.getSlaveId.getValue
- offerableIndices.put(slaveId, index)
- val cpus = if (slaveIdsWithExecutors.contains(slaveId)) {
- getResource(offer.getResourcesList, "cpus").toInt
- } else {
- // If the executor doesn't exist yet, subtract CPU for executor
- getResource(offer.getResourcesList, "cpus").toInt -
- scheduler.CPUS_PER_TASK
- }
- offerableWorkers += new WorkerOffer(
- offer.getSlaveId.getValue,
- offer.getHostname,
- cpus)
+ val workerOffers = usableOffers.map { o =>
+ val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) {
+ getResource(o.getResourcesList, "cpus").toInt
+ } else {
+ // If the executor doesn't exist yet, subtract CPU for executor
+ // TODO(pwendell): Should below just subtract "1"?
+ getResource(o.getResourcesList, "cpus").toInt -
+ scheduler.CPUS_PER_TASK
}
+ new WorkerOffer(
+ o.getSlaveId.getValue,
+ o.getHostname,
+ cpus)
+ }
+
+ val slaveIdToOffer = usableOffers.map(o => o.getSlaveId.getValue -> o).toMap
+
+ val mesosTasks = new HashMap[String, JArrayList[MesosTaskInfo]]
- // Call into the TaskSchedulerImpl
- val taskLists = scheduler.resourceOffers(offerableWorkers)
-
- // Build a list of Mesos tasks for each slave
- val mesosTasks = offers.map(o => new JArrayList[MesosTaskInfo]())
- for ((taskList, index) <- taskLists.zipWithIndex) {
- if (!taskList.isEmpty) {
- for (taskDesc <- taskList) {
- val slaveId = taskDesc.executorId
- val offerNum = offerableIndices(slaveId)
- slaveIdsWithExecutors += slaveId
- taskIdToSlaveId(taskDesc.taskId) = slaveId
- mesosTasks(offerNum).add(createMesosTask(taskDesc, slaveId))
- }
+ val slavesIdsOfAcceptedOffers = HashSet[String]()
+
+ // Call into the TaskSchedulerImpl
+ val acceptedOffers = scheduler.resourceOffers(workerOffers).filter(!_.isEmpty)
+ acceptedOffers
+ .foreach { offer =>
+ offer.foreach { taskDesc =>
+ val slaveId = taskDesc.executorId
+ slaveIdsWithExecutors += slaveId
+ slavesIdsOfAcceptedOffers += slaveId
+ taskIdToSlaveId(taskDesc.taskId) = slaveId
+ mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo])
+ .add(createMesosTask(taskDesc, slaveId))
}
}
- // Reply to the offers
- val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout?
- for (i <- 0 until offers.size) {
- d.launchTasks(Collections.singleton(offers(i).getId), mesosTasks(i), filters)
- }
+ // Reply to the offers
+ val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout?
+
+ mesosTasks.foreach { case (slaveId, tasks) =>
+ d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters)
}
- } finally {
- restoreClassLoader(oldClassLoader)
+
+ // Decline offers that weren't used
+ // NOTE: This logic assumes that we only get a single offer for each host in a given batch
+ for (o <- usableOffers if !slavesIdsOfAcceptedOffers.contains(o.getSlaveId.getValue)) {
+ d.declineOffer(o.getId)
+ }
+
+ // Decline offers we ruled out immediately
+ unUsableOffers.foreach(o => d.declineOffer(o.getId))
}
}
@@ -276,8 +279,7 @@ private[spark] class MesosSchedulerBackend(
for (r <- res if r.getName == name) {
return r.getScalar.getValue
}
- // If we reached here, no resource with the required name was present
- throw new IllegalArgumentException("No resource called " + name + " in " + res)
+ 0
}
/** Turn a Spark TaskDescription into a Mesos task */
@@ -307,8 +309,7 @@ private[spark] class MesosSchedulerBackend(
}
override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
- val oldClassLoader = setClassLoader()
- try {
+ inClassLoader() {
val tid = status.getTaskId.getValue.toLong
val state = TaskState.fromMesos(status.getState)
synchronized {
@@ -321,18 +322,13 @@ private[spark] class MesosSchedulerBackend(
}
}
scheduler.statusUpdate(tid, state, status.getData.asReadOnlyByteBuffer)
- } finally {
- restoreClassLoader(oldClassLoader)
}
}
override def error(d: SchedulerDriver, message: String) {
- val oldClassLoader = setClassLoader()
- try {
+ inClassLoader() {
logError("Mesos error: " + message)
scheduler.error(message)
- } finally {
- restoreClassLoader(oldClassLoader)
}
}
@@ -349,15 +345,12 @@ private[spark] class MesosSchedulerBackend(
override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {}
private def recordSlaveLost(d: SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason) {
- val oldClassLoader = setClassLoader()
- try {
+ inClassLoader() {
logInfo("Mesos slave lost: " + slaveId.getValue)
synchronized {
slaveIdsWithExecutors -= slaveId.getValue
}
scheduler.executorLost(slaveId.getValue, reason)
- } finally {
- restoreClassLoader(oldClassLoader)
}
}
@@ -372,6 +365,13 @@ private[spark] class MesosSchedulerBackend(
recordSlaveLost(d, slaveId, ExecutorExited(status))
}
+ override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = {
+ driver.killTask(
+ TaskID.newBuilder()
+ .setValue(taskId.toString).build()
+ )
+ }
+
// TODO: query Mesos for number of cores
override def defaultParallelism() = sc.conf.getInt("spark.default.parallelism", 8)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
index 58b78f041cd85..a2f1f14264a99 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -21,7 +21,7 @@ import java.nio.ByteBuffer
import akka.actor.{Actor, ActorRef, Props}
-import org.apache.spark.{Logging, SparkEnv, TaskState}
+import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.executor.{Executor, ExecutorBackend}
import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer}
@@ -47,11 +47,11 @@ private[spark] class LocalActor(
private var freeCores = totalCores
- private val localExecutorId = "localhost"
+ private val localExecutorId = SparkContext.DRIVER_IDENTIFIER
private val localExecutorHostname = "localhost"
val executor = new Executor(
- localExecutorId, localExecutorHostname, scheduler.conf.getAll, isLocal = true)
+ localExecutorId, localExecutorHostname, scheduler.conf.getAll, totalCores, isLocal = true)
override def receiveWithLogging = {
case ReviveOffers =>
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index 554a33ce7f1a6..662a7b91248aa 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -117,11 +117,11 @@ class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable {
new JavaSerializerInstance(counterReset, classLoader)
}
- override def writeExternal(out: ObjectOutput) {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
out.writeInt(counterReset)
}
- override def readExternal(in: ObjectInput) {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
counterReset = in.readInt()
}
}
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index d6386f8c06fff..621a951c27d07 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -53,7 +53,18 @@ class KryoSerializer(conf: SparkConf)
private val maxBufferSize = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) * 1024 * 1024
private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true)
private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false)
- private val registrator = conf.getOption("spark.kryo.registrator")
+ private val userRegistrator = conf.getOption("spark.kryo.registrator")
+ private val classesToRegister = conf.get("spark.kryo.classesToRegister", "")
+ .split(',')
+ .filter(!_.isEmpty)
+ .map { className =>
+ try {
+ Class.forName(className)
+ } catch {
+ case e: Exception =>
+ throw new SparkException("Failed to load class to register with Kryo", e)
+ }
+ }
def newKryoOutput() = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize))
@@ -80,22 +91,20 @@ class KryoSerializer(conf: SparkConf)
kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer())
kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer())
- // Allow the user to register their own classes by setting spark.kryo.registrator
- for (regCls <- registrator) {
- logDebug("Running user registrator: " + regCls)
- try {
- val reg = Class.forName(regCls, true, classLoader).newInstance()
- .asInstanceOf[KryoRegistrator]
-
- // Use the default classloader when calling the user registrator.
- Thread.currentThread.setContextClassLoader(classLoader)
- reg.registerClasses(kryo)
- } catch {
- case e: Exception =>
- throw new SparkException(s"Failed to invoke $regCls", e)
- } finally {
- Thread.currentThread.setContextClassLoader(oldClassLoader)
- }
+ try {
+ // Use the default classloader when calling the user registrator.
+ Thread.currentThread.setContextClassLoader(classLoader)
+ // Register classes given through spark.kryo.classesToRegister.
+ classesToRegister.foreach { clazz => kryo.register(clazz) }
+ // Allow the user to register their own classes by setting spark.kryo.registrator.
+ userRegistrator
+ .map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator])
+ .foreach { reg => reg.registerClasses(kryo) }
+ } catch {
+ case e: Exception =>
+ throw new SparkException(s"Failed to register classes with Kryo", e)
+ } finally {
+ Thread.currentThread.setContextClassLoader(oldClassLoader)
}
// Register Chill's classes; we do this after our ranges and the user's own classes to let
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index a9144cdd97b8c..ca6e971d227fb 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -17,14 +17,14 @@
package org.apache.spark.serializer
-import java.io.{ByteArrayOutputStream, EOFException, InputStream, OutputStream}
+import java.io._
import java.nio.ByteBuffer
import scala.reflect.ClassTag
-import org.apache.spark.SparkEnv
+import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.util.{ByteBufferInputStream, NextIterator}
+import org.apache.spark.util.{Utils, ByteBufferInputStream, NextIterator}
/**
* :: DeveloperApi ::
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
index 71c08e9d5a8c3..be184464e0ae9 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
@@ -19,6 +19,7 @@ package org.apache.spark.shuffle
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.{FetchFailed, TaskEndReason}
+import org.apache.spark.util.Utils
/**
* Failed to fetch a shuffle block. The executor catches this exception and propagates it
@@ -30,13 +31,22 @@ private[spark] class FetchFailedException(
bmAddress: BlockManagerId,
shuffleId: Int,
mapId: Int,
- reduceId: Int)
- extends Exception {
+ reduceId: Int,
+ message: String,
+ cause: Throwable = null)
+ extends Exception(message, cause) {
- override def getMessage: String =
- "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId)
+ def this(
+ bmAddress: BlockManagerId,
+ shuffleId: Int,
+ mapId: Int,
+ reduceId: Int,
+ cause: Throwable) {
+ this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause)
+ }
- def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId)
+ def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId,
+ Utils.exceptionString(this))
}
/**
@@ -46,7 +56,4 @@ private[spark] class MetadataFetchFailedException(
shuffleId: Int,
reduceId: Int,
message: String)
- extends FetchFailedException(null, shuffleId, -1, reduceId) {
-
- override def getMessage: String = message
-}
+ extends FetchFailedException(null, shuffleId, -1, reduceId, message)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
index 439981d232349..7de2f9cbb2866 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
@@ -24,9 +24,10 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.JavaConversions._
-import org.apache.spark.{SparkEnv, SparkConf, Logging}
+import org.apache.spark.{Logging, SparkConf, SparkEnv}
import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup
import org.apache.spark.storage._
@@ -62,11 +63,14 @@ private[spark] trait ShuffleWriterGroup {
* each block stored in each file. In order to find the location of a shuffle block, we search the
* files within a ShuffleFileGroups associated with the block's reducer.
*/
-
+// Note: Changes to the format in this file should be kept in sync with
+// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getHashBasedShuffleBlockData().
private[spark]
class FileShuffleBlockManager(conf: SparkConf)
extends ShuffleBlockManager with Logging {
+ private val transportConf = SparkTransportConf.fromSparkConf(conf)
+
private lazy val blockManager = SparkEnv.get.blockManager
// Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
@@ -181,13 +185,14 @@ class FileShuffleBlockManager(conf: SparkConf)
val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId)
if (segmentOpt.isDefined) {
val segment = segmentOpt.get
- return new FileSegmentManagedBuffer(segment.file, segment.offset, segment.length)
+ return new FileSegmentManagedBuffer(
+ transportConf, segment.file, segment.offset, segment.length)
}
}
throw new IllegalStateException("Failed to find shuffle block: " + blockId)
} else {
val file = blockManager.diskBlockManager.getFile(blockId)
- new FileSegmentManagedBuffer(file, 0, file.length)
+ new FileSegmentManagedBuffer(transportConf, file, 0, file.length)
}
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
index 4ab34336d3f01..b292587d37028 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
@@ -20,8 +20,11 @@ package org.apache.spark.shuffle
import java.io._
import java.nio.ByteBuffer
-import org.apache.spark.SparkEnv
-import org.apache.spark.network.{ManagedBuffer, FileSegmentManagedBuffer}
+import com.google.common.io.ByteStreams
+
+import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.storage._
/**
@@ -33,11 +36,15 @@ import org.apache.spark.storage._
* as the filename postfix for data file, and ".index" as the filename postfix for index file.
*
*/
+// Note: Changes to the format in this file should be kept in sync with
+// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getSortBasedShuffleBlockData().
private[spark]
-class IndexShuffleBlockManager extends ShuffleBlockManager {
+class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager {
private lazy val blockManager = SparkEnv.get.blockManager
+ private val transportConf = SparkTransportConf.fromSparkConf(conf)
+
/**
* Mapping to a single shuffleBlockId with reduce ID 0.
* */
@@ -101,10 +108,11 @@ class IndexShuffleBlockManager extends ShuffleBlockManager {
val in = new DataInputStream(new FileInputStream(indexFile))
try {
- in.skip(blockId.reduceId * 8)
+ ByteStreams.skipFully(in, blockId.reduceId * 8)
val offset = in.readLong()
val nextOffset = in.readLong()
new FileSegmentManagedBuffer(
+ transportConf,
getDataFile(blockId.shuffleId, blockId.mapId),
offset,
nextOffset - offset)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
index 63863cc0250a3..b521f0c7fc77e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
@@ -18,8 +18,7 @@
package org.apache.spark.shuffle
import java.nio.ByteBuffer
-
-import org.apache.spark.network.ManagedBuffer
+import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.storage.ShuffleBlockId
private[spark]
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala
index b30e366d06006..292e48314ee10 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala
@@ -24,6 +24,10 @@ private[spark] trait ShuffleReader[K, C] {
/** Read the combined key-values for this reduce task */
def read(): Iterator[Product2[K, C]]
- /** Close this reader */
- def stop(): Unit
+ /**
+ * Close this reader.
+ * TODO: Add this back when we make the ShuffleReader a developer API that others can implement
+ * (at which point this will likely be necessary).
+ */
+ // def stop(): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index 6cf9305977a3c..e3e7434df45b0 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -19,6 +19,7 @@ package org.apache.spark.shuffle.hash
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
+import scala.util.{Failure, Success, Try}
import org.apache.spark._
import org.apache.spark.serializer.Serializer
@@ -52,21 +53,21 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
(address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
}
- def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = {
+ def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = {
val blockId = blockPair._1
val blockOption = blockPair._2
blockOption match {
- case Some(block) => {
+ case Success(block) => {
block.asInstanceOf[Iterator[T]]
}
- case None => {
+ case Failure(e) => {
blockId match {
case ShuffleBlockId(shufId, mapId, _) =>
val address = statuses(mapId.toInt)._1
- throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId)
+ throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
case _ =>
throw new SparkException(
- "Failed to get block " + blockId + ", which is not a shuffle block")
+ "Failed to get block " + blockId + ", which is not a shuffle block", e)
}
}
}
@@ -74,7 +75,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
val blockFetcherItr = new ShuffleBlockFetcherIterator(
context,
- SparkEnv.get.blockTransferService,
+ SparkEnv.get.blockManager.shuffleClient,
blockManager,
blocksByAddress,
serializer,
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index 88a5f1e5ddf58..5baf45db45c17 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -66,7 +66,4 @@ private[spark] class HashShuffleReader[K, C](
aggregatedIter
}
}
-
- /** Close this reader */
- override def stop(): Unit = ???
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 746ed33b54c00..183a30373b28c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -107,7 +107,7 @@ private[spark] class HashShuffleWriter[K, V](
writer.commitAndClose()
writer.fileSegment().length
}
- MapStatus(blockManager.blockManagerId, sizes)
+ MapStatus(blockManager.shuffleServerId, sizes)
}
private def revertWrites(): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index b727438ae7e47..bda30a56d808e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -25,7 +25,7 @@ import org.apache.spark.shuffle.hash.HashShuffleReader
private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager {
- private val indexShuffleBlockManager = new IndexShuffleBlockManager()
+ private val indexShuffleBlockManager = new IndexShuffleBlockManager(conf)
private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]()
/**
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 927481b72cf4f..d75f9d7311fad 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -70,7 +70,7 @@ private[spark] class SortShuffleWriter[K, V, C](
val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
- mapStatus = MapStatus(blockManager.blockManagerId, partitionLengths)
+ mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
}
/** Close this writer, passing along whether the map completed */
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala b/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala
deleted file mode 100644
index 5b6d086630834..0000000000000
--- a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.nio.ByteBuffer
-
-
-/**
- * An interface for providing data for blocks.
- *
- * getBlockData returns either a FileSegment (for zero-copy send), or a ByteBuffer.
- *
- * Aside from unit tests, [[BlockManager]] is the main class that implements this.
- */
-private[spark] trait BlockDataProvider {
- def getBlockData(blockId: String): Either[FileSegment, ByteBuffer]
-}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index a83a3f468ae5f..1f012941c85ab 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -53,6 +53,8 @@ case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId {
def name = "rdd_" + rddId + "_" + splitIndex
}
+// Format of the shuffle block ids (including data and index) should be kept in sync with
+// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getBlockData().
@DeveloperApi
case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
@@ -83,9 +85,14 @@ case class StreamBlockId(streamId: Int, uniqueId: Long) extends BlockId {
def name = "input-" + streamId + "-" + uniqueId
}
-/** Id associated with temporary data managed as blocks. Not serializable. */
-private[spark] case class TempBlockId(id: UUID) extends BlockId {
- def name = "temp_" + id
+/** Id associated with temporary local data managed as blocks. Not serializable. */
+private[spark] case class TempLocalBlockId(id: UUID) extends BlockId {
+ def name = "temp_local_" + id
+}
+
+/** Id associated with temporary shuffle data managed as blocks. Not serializable. */
+private[spark] case class TempShuffleBlockId(id: UUID) extends BlockId {
+ def name = "temp_shuffle_" + id
}
// Intended only for testing purposes
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 3f5d06e1aeee7..308c59eda594d 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -17,14 +17,12 @@
package org.apache.spark.storage
-import java.io.{File, InputStream, OutputStream, BufferedOutputStream, ByteArrayOutputStream}
+import java.io.{BufferedOutputStream, ByteArrayOutputStream, File, InputStream, OutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
-import scala.concurrent.ExecutionContext.Implicits.global
-
-import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.concurrent.{Await, Future}
+import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.util.Random
@@ -35,11 +33,16 @@ import org.apache.spark._
import org.apache.spark.executor._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.netty.{SparkTransportConf, NettyBlockTransferService}
+import org.apache.spark.network.shuffle.ExternalShuffleClient
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
+import org.apache.spark.network.util.{ConfigProvider, TransportConf}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleManager
+import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.util._
-
private[spark] sealed trait BlockValues
private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues
private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues
@@ -54,6 +57,12 @@ private[spark] class BlockResult(
inputMetrics.bytesRead = bytes
}
+/**
+ * Manager running on every node (driver and executors) which provides interfaces for putting and
+ * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap).
+ *
+ * Note that #initialize() must be called before the BlockManager is usable.
+ */
private[spark] class BlockManager(
executorId: String,
actorSystem: ActorSystem,
@@ -63,11 +72,11 @@ private[spark] class BlockManager(
val conf: SparkConf,
mapOutputTracker: MapOutputTracker,
shuffleManager: ShuffleManager,
- blockTransferService: BlockTransferService)
+ blockTransferService: BlockTransferService,
+ securityManager: SecurityManager,
+ numUsableCores: Int)
extends BlockDataManager with Logging {
- blockTransferService.init(this)
-
val diskBlockManager = new DiskBlockManager(this, conf)
private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
@@ -87,8 +96,37 @@ private[spark] class BlockManager(
new TachyonStore(this, tachyonBlockManager)
}
- val blockManagerId = BlockManagerId(
- executorId, blockTransferService.hostName, blockTransferService.port)
+ private[spark]
+ val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false)
+
+ // Port used by the external shuffle service. In Yarn mode, this may be already be
+ // set through the Hadoop configuration as the server is launched in the Yarn NM.
+ private val externalShuffleServicePort =
+ Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt
+
+ // Check that we're not using external shuffle service with consolidated shuffle files.
+ if (externalShuffleServiceEnabled
+ && conf.getBoolean("spark.shuffle.consolidateFiles", false)
+ && shuffleManager.isInstanceOf[HashShuffleManager]) {
+ throw new UnsupportedOperationException("Cannot use external shuffle service with consolidated"
+ + " shuffle files in hash-based shuffle. Please disable spark.shuffle.consolidateFiles or "
+ + " switch to sort-based shuffle.")
+ }
+
+ var blockManagerId: BlockManagerId = _
+
+ // Address of the server that serves this executor's shuffle files. This is either an external
+ // service, or just our own Executor's BlockManager.
+ private[spark] var shuffleServerId: BlockManagerId = _
+
+ // Client to read other executors' shuffle files. This is either an external service, or just the
+ // standard BlockTranserService to directly connect to other Executors.
+ private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
+ val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores)
+ new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled())
+ } else {
+ blockTransferService
+ }
// Whether to compress broadcast variables that are stored
private val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true)
@@ -118,8 +156,6 @@ private[spark] class BlockManager(
private val peerFetchLock = new Object
private var lastPeerFetchTime = 0L
- initialize()
-
/* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
* the initialization of the compression codec until it is first used. The reason is that a Spark
* program could be using a user-defined codec in a third party jar, which is loaded in
@@ -138,17 +174,66 @@ private[spark] class BlockManager(
conf: SparkConf,
mapOutputTracker: MapOutputTracker,
shuffleManager: ShuffleManager,
- blockTransferService: BlockTransferService) = {
+ blockTransferService: BlockTransferService,
+ securityManager: SecurityManager,
+ numUsableCores: Int) = {
this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf),
- conf, mapOutputTracker, shuffleManager, blockTransferService)
+ conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores)
}
/**
- * Initialize the BlockManager. Register to the BlockManagerMaster, and start the
- * BlockManagerWorker actor.
+ * Initializes the BlockManager with the given appId. This is not performed in the constructor as
+ * the appId may not be known at BlockManager instantiation time (in particular for the driver,
+ * where it is only learned after registration with the TaskScheduler).
+ *
+ * This method initializes the BlockTransferService and ShuffleClient, registers with the
+ * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle
+ * service if configured.
*/
- private def initialize(): Unit = {
+ def initialize(appId: String): Unit = {
+ blockTransferService.init(this)
+ shuffleClient.init(appId)
+
+ blockManagerId = BlockManagerId(
+ executorId, blockTransferService.hostName, blockTransferService.port)
+
+ shuffleServerId = if (externalShuffleServiceEnabled) {
+ BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort)
+ } else {
+ blockManagerId
+ }
+
master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
+
+ // Register Executors' configuration with the local shuffle service, if one should exist.
+ if (externalShuffleServiceEnabled && !blockManagerId.isDriver) {
+ registerWithExternalShuffleServer()
+ }
+ }
+
+ private def registerWithExternalShuffleServer() {
+ logInfo("Registering executor with local external shuffle service.")
+ val shuffleConfig = new ExecutorShuffleInfo(
+ diskBlockManager.localDirs.map(_.toString),
+ diskBlockManager.subDirsPerLocalDir,
+ shuffleManager.getClass.getName)
+
+ val MAX_ATTEMPTS = 3
+ val SLEEP_TIME_SECS = 5
+
+ for (i <- 1 to MAX_ATTEMPTS) {
+ try {
+ // Synchronous and will throw an exception if we cannot connect.
+ shuffleClient.asInstanceOf[ExternalShuffleClient].registerWithShuffleServer(
+ shuffleServerId.host, shuffleServerId.port, shuffleServerId.executorId, shuffleConfig)
+ return
+ } catch {
+ case e: Exception if i < MAX_ATTEMPTS =>
+ logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}}"
+ + s" more times after waiting $SLEEP_TIME_SECS seconds...", e)
+ Thread.sleep(SLEEP_TIME_SECS * 1000)
+ }
+ }
}
/**
@@ -212,21 +297,20 @@ private[spark] class BlockManager(
}
/**
- * Interface to get local block data.
- *
- * @return Some(buffer) if the block exists locally, and None if it doesn't.
+ * Interface to get local block data. Throws an exception if the block cannot be found or
+ * cannot be read successfully.
*/
- override def getBlockData(blockId: String): Option[ManagedBuffer] = {
- val bid = BlockId(blockId)
- if (bid.isShuffle) {
- Some(shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId]))
+ override def getBlockData(blockId: BlockId): ManagedBuffer = {
+ if (blockId.isShuffle) {
+ shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
} else {
- val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]]
+ val blockBytesOpt = doGetLocal(blockId, asBlockResult = false)
+ .asInstanceOf[Option[ByteBuffer]]
if (blockBytesOpt.isDefined) {
val buffer = blockBytesOpt.get
- Some(new NioByteBufferManagedBuffer(buffer))
+ new NioManagedBuffer(buffer)
} else {
- None
+ throw new BlockNotFoundException(blockId.toString)
}
}
}
@@ -234,8 +318,8 @@ private[spark] class BlockManager(
/**
* Put the block locally, using the given storage level.
*/
- override def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit = {
- putBytes(BlockId(blockId), data.nioByteBuffer(), level)
+ override def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit = {
+ putBytes(blockId, data.nioByteBuffer(), level)
}
/**
@@ -340,17 +424,6 @@ private[spark] class BlockManager(
locations
}
- /**
- * A short-circuited method to get blocks directly from disk. This is used for getting
- * shuffle blocks. It is safe to do so without a lock on block info since disk store
- * never deletes (recent) items.
- */
- def getLocalShuffleFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
- val buf = shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
- val is = wrapForCompression(blockId, buf.inputStream())
- Some(serializer.newInstance().deserializeStream(is).asIterator)
- }
-
/**
* Get block from local block manager.
*/
@@ -520,7 +593,7 @@ private[spark] class BlockManager(
for (loc <- locations) {
logDebug(s"Getting remote block $blockId from $loc")
val data = blockTransferService.fetchBlockSync(
- loc.host, loc.port, blockId.toString).nioByteBuffer()
+ loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer()
if (data != null) {
if (asBlockResult) {
@@ -869,9 +942,9 @@ private[spark] class BlockManager(
data.rewind()
logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer")
blockTransferService.uploadBlockSync(
- peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel)
- logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %f ms"
- .format((System.currentTimeMillis - onePeerStartTime)))
+ peer.host, peer.port, peer.executorId, blockId, new NioManagedBuffer(data), tLevel)
+ logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %s ms"
+ .format(System.currentTimeMillis - onePeerStartTime))
peersReplicatedTo += peer
peersForReplication -= peer
replicationFailed = false
@@ -1071,7 +1144,8 @@ private[spark] class BlockManager(
case _: ShuffleBlockId => compressShuffle
case _: BroadcastBlockId => compressBroadcast
case _: RDDBlockId => compressRdds
- case _: TempBlockId => compressShuffleSpill
+ case _: TempLocalBlockId => compressShuffleSpill
+ case _: TempShuffleBlockId => compressShuffle
case _ => false
}
}
@@ -1125,7 +1199,11 @@ private[spark] class BlockManager(
}
def stop(): Unit = {
- blockTransferService.stop()
+ blockTransferService.close()
+ if (shuffleClient ne blockTransferService) {
+ // Closing should be idempotent, but maybe not for the NioBlockTransferService.
+ shuffleClient.close()
+ }
diskBlockManager.stop()
actorSystem.stop(slaveActor)
blockInfo.clear()
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
index 142285094342c..b177a59c721df 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
@@ -20,6 +20,7 @@ package org.apache.spark.storage
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import java.util.concurrent.ConcurrentHashMap
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils
@@ -59,15 +60,15 @@ class BlockManagerId private (
def port: Int = port_
- def isDriver: Boolean = (executorId == "")
+ def isDriver: Boolean = { executorId == SparkContext.DRIVER_IDENTIFIER }
- override def writeExternal(out: ObjectOutput) {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
out.writeUTF(executorId_)
out.writeUTF(host_)
out.writeInt(port_)
}
- override def readExternal(in: ObjectInput) {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
executorId_ = in.readUTF()
host_ = in.readUTF()
port_ = in.readInt()
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index d08e1419e3e41..b63c7f191155c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -88,6 +88,10 @@ class BlockManagerMaster(
askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId))
}
+ def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = {
+ askDriverWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId))
+ }
+
/**
* Remove a block from the slaves that have it. This can only be used to remove
* blocks that the driver knows about.
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 6a06257ed0c08..685b2e11440fb 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -86,6 +86,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
case GetPeers(blockManagerId) =>
sender ! getPeers(blockManagerId)
+ case GetActorSystemHostPortForExecutor(executorId) =>
+ sender ! getActorSystemHostPortForExecutor(executorId)
+
case GetMemoryStatus =>
sender ! memoryStatus
@@ -203,6 +206,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
}
}
listenerBus.post(SparkListenerBlockManagerRemoved(System.currentTimeMillis(), blockManagerId))
+ logInfo(s"Removing block manager $blockManagerId")
}
private def expireDeadHosts() {
@@ -327,20 +331,20 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
val time = System.currentTimeMillis()
if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
- case Some(manager) =>
- // A block manager of the same executor already exists.
- // This should never happen. Let's just quit.
- logError("Got two different block manager registrations on " + id.executorId)
- System.exit(1)
+ case Some(oldId) =>
+ // A block manager of the same executor already exists, so remove it (assumed dead)
+ logError("Got two different block manager registrations on same executor - "
+ + s" will replace old one $oldId with new one $id")
+ removeExecutor(id.executorId)
case None =>
- blockManagerIdByExecutor(id.executorId) = id
}
-
- logInfo("Registering block manager %s with %s RAM".format(
- id.hostPort, Utils.bytesToString(maxMemSize)))
-
- blockManagerInfo(id) =
- new BlockManagerInfo(id, time, maxMemSize, slaveActor)
+ logInfo("Registering block manager %s with %s RAM, %s".format(
+ id.hostPort, Utils.bytesToString(maxMemSize), id))
+
+ blockManagerIdByExecutor(id.executorId) = id
+
+ blockManagerInfo(id) = new BlockManagerInfo(
+ id, System.currentTimeMillis(), maxMemSize, slaveActor)
}
listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize))
}
@@ -411,6 +415,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
Seq.empty
}
}
+
+ /**
+ * Returns the hostname and port of an executor's actor system, based on the Akka address of its
+ * BlockManagerSlaveActor.
+ */
+ private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = {
+ for (
+ blockManagerId <- blockManagerIdByExecutor.get(executorId);
+ info <- blockManagerInfo.get(blockManagerId);
+ host <- info.slaveActor.path.address.host;
+ port <- info.slaveActor.path.address.port
+ ) yield {
+ (host, port)
+ }
+ }
}
@DeveloperApi
@@ -457,16 +476,18 @@ private[spark] class BlockManagerInfo(
if (_blocks.containsKey(blockId)) {
// The block exists on the slave already.
- val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel
+ val blockStatus: BlockStatus = _blocks.get(blockId)
+ val originalLevel: StorageLevel = blockStatus.storageLevel
+ val originalMemSize: Long = blockStatus.memSize
if (originalLevel.useMemory) {
- _remainingMem += memSize
+ _remainingMem += originalMemSize
}
}
if (storageLevel.isValid) {
/* isValid means it is either stored in-memory, on-disk or on-Tachyon.
- * But the memSize here indicates the data size in or dropped from memory,
+ * The memSize here indicates the data size in or dropped from memory,
* tachyonSize here indicates the data size in or dropped from Tachyon,
* and the diskSize here indicates the data size in or dropped to disk.
* They can be both larger than 0, when a block is dropped from memory to disk.
@@ -493,7 +514,6 @@ private[spark] class BlockManagerInfo(
val blockStatus: BlockStatus = _blocks.get(blockId)
_blocks.remove(blockId)
if (blockStatus.storageLevel.useMemory) {
- _remainingMem += blockStatus.memSize
logInfo("Removed %s on %s in memory (size: %s, free: %s)".format(
blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize),
Utils.bytesToString(_remainingMem)))
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 3db5dd9774ae8..3f32099d08cc9 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -21,6 +21,8 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput}
import akka.actor.ActorRef
+import org.apache.spark.util.Utils
+
private[spark] object BlockManagerMessages {
//////////////////////////////////////////////////////////////////////////////////
// Messages from the master to slaves.
@@ -65,7 +67,7 @@ private[spark] object BlockManagerMessages {
def this() = this(null, null, null, 0, 0, 0) // For deserialization only
- override def writeExternal(out: ObjectOutput) {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
blockManagerId.writeExternal(out)
out.writeUTF(blockId.name)
storageLevel.writeExternal(out)
@@ -74,7 +76,7 @@ private[spark] object BlockManagerMessages {
out.writeLong(tachyonSize)
}
- override def readExternal(in: ObjectInput) {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
blockManagerId = BlockManagerId(in)
blockId = BlockId(in.readUTF())
storageLevel = StorageLevel(in)
@@ -90,6 +92,8 @@ private[spark] object BlockManagerMessages {
case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
+ case class GetActorSystemHostPortForExecutor(executorId: String) extends ToBlockManagerMaster
+
case class RemoveExecutor(execId: String) extends ToBlockManagerMaster
case object StopBlockManagerMaster extends ToBlockManagerMaster
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala
index 9ef453605f4f1..81f5f2d31dbd8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala
@@ -17,5 +17,4 @@
package org.apache.spark.storage
-
class BlockNotFoundException(blockId: String) extends Exception(s"Block $blockId not found")
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index a715594f198c2..58fba54710510 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -38,12 +38,13 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
extends Logging {
private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
- private val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64)
+ private[spark]
+ val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64)
/* Create one local directory for each path mentioned in spark.local.dir; then, inside this
* directory, create multiple subdirectories that we will hash files into, in order to avoid
* having really large inodes at the top level. */
- val localDirs: Array[File] = createLocalDirs(conf)
+ private[spark] val localDirs: Array[File] = createLocalDirs(conf)
if (localDirs.isEmpty) {
logError("Failed to create any local dir.")
System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
@@ -52,6 +53,9 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
addShutdownHook()
+ /** Looks up a file by hashing it into one of our local subdirectories. */
+ // This method should be kept in sync with
+ // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getFile().
def getFile(filename: String): File = {
// Figure out which local directory it hashes to, and which subdirectory in that
val hash = Utils.nonNegativeHash(filename)
@@ -98,11 +102,20 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
getAllFiles().map(f => BlockId(f.getName))
}
- /** Produces a unique block id and File suitable for intermediate results. */
- def createTempBlock(): (TempBlockId, File) = {
- var blockId = new TempBlockId(UUID.randomUUID())
+ /** Produces a unique block id and File suitable for storing local intermediate results. */
+ def createTempLocalBlock(): (TempLocalBlockId, File) = {
+ var blockId = new TempLocalBlockId(UUID.randomUUID())
while (getFile(blockId).exists()) {
- blockId = new TempBlockId(UUID.randomUUID())
+ blockId = new TempLocalBlockId(UUID.randomUUID())
+ }
+ (blockId, getFile(blockId))
+ }
+
+ /** Produces a unique block id and File suitable for storing shuffled intermediate results. */
+ def createTempShuffleBlock(): (TempShuffleBlockId, File) = {
+ var blockId = new TempShuffleBlockId(UUID.randomUUID())
+ while (getFile(blockId).exists()) {
+ blockId = new TempShuffleBlockId(UUID.randomUUID())
}
(blockId, getFile(blockId))
}
@@ -140,7 +153,6 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
}
private def addShutdownHook() {
- localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
override def run(): Unit = Utils.logUncaughtExceptions {
logDebug("Shutdown hook called")
@@ -151,13 +163,16 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
/** Cleanup local dirs and stop shuffle sender. */
private[spark] def stop() {
- localDirs.foreach { localDir =>
- if (localDir.isDirectory() && localDir.exists()) {
- try {
- if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
- } catch {
- case e: Exception =>
- logError(s"Exception while deleting local spark dir: $localDir", e)
+ // Only perform cleanup if an external service is not serving our shuffle files.
+ if (!blockManager.externalShuffleServiceEnabled) {
+ localDirs.foreach { localDir =>
+ if (localDir.isDirectory() && localDir.exists()) {
+ try {
+ if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
+ } catch {
+ case e: Exception =>
+ logError(s"Exception while deleting local spark dir: $localDir", e)
+ }
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index e9304f6bb45d0..8dadf6794039e 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -17,7 +17,7 @@
package org.apache.spark.storage
-import java.io.{File, FileOutputStream, RandomAccessFile}
+import java.io.{IOException, File, FileOutputStream, RandomAccessFile}
import java.nio.ByteBuffer
import java.nio.channels.FileChannel.MapMode
@@ -73,7 +73,21 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc
val startTime = System.currentTimeMillis
val file = diskManager.getFile(blockId)
val outputStream = new FileOutputStream(file)
- blockManager.dataSerializeStream(blockId, outputStream, values)
+ try {
+ try {
+ blockManager.dataSerializeStream(blockId, outputStream, values)
+ } finally {
+ // Close outputStream here because it should be closed before file is deleted.
+ outputStream.close()
+ }
+ } catch {
+ case e: Throwable =>
+ if (file.exists()) {
+ file.delete()
+ }
+ throw e
+ }
+
val length = file.length
val timeTaken = System.currentTimeMillis - startTime
@@ -96,7 +110,13 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc
// For small files, directly read rather than memory map
if (length < minMemoryMapBytes) {
val buf = ByteBuffer.allocate(length.toInt)
- channel.read(buf, offset)
+ channel.position(offset)
+ while (buf.remaining() != 0) {
+ if (channel.read(buf) == -1) {
+ throw new IOException("Reached EOF before filling buffer\n" +
+ s"offset=$offset\nfile=${file.getAbsolutePath}\nbuf.remaining=${buf.remaining}")
+ }
+ }
buf.flip()
Some(buf)
} else {
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
index 0a09c24d61879..71305a46bf570 100644
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -56,6 +56,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
(maxMemory * unrollFraction).toLong
}
+ // Initial memory to request before unrolling any block
+ private val unrollMemoryThreshold: Long =
+ conf.getLong("spark.storage.unrollMemoryThreshold", 1024 * 1024)
+
+ if (maxMemory < unrollMemoryThreshold) {
+ logWarning(s"Max memory ${Utils.bytesToString(maxMemory)} is less than the initial memory " +
+ s"threshold ${Utils.bytesToString(unrollMemoryThreshold)} needed to store a block in " +
+ s"memory. Please configure Spark with more memory.")
+ }
+
logInfo("MemoryStore started with capacity %s".format(Utils.bytesToString(maxMemory)))
/** Free memory not occupied by existing blocks. Note that this does not include unroll memory. */
@@ -132,8 +142,6 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
PutResult(res.size, res.data, droppedBlocks)
case Right(iteratorValues) =>
// Not enough space to unroll this block; drop to disk if applicable
- logWarning(s"Not enough space to store block $blockId in memory! " +
- s"Free memory is $freeMemory bytes.")
if (level.useDisk && allowPersistToDisk) {
logWarning(s"Persisting block $blockId to disk instead.")
val res = blockManager.diskStore.putIterator(blockId, iteratorValues, level, returnValues)
@@ -215,7 +223,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
// Whether there is still enough memory for us to continue unrolling this block
var keepUnrolling = true
// Initial per-thread memory to request for unrolling blocks (bytes). Exposed for testing.
- val initialMemoryThreshold = conf.getLong("spark.storage.unrollMemoryThreshold", 1024 * 1024)
+ val initialMemoryThreshold = unrollMemoryThreshold
// How often to check whether we need to request more memory
val memoryCheckPeriod = 16
// Memory currently reserved by this thread for this particular unrolling operation
@@ -230,6 +238,11 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
// Request enough memory to begin unrolling
keepUnrolling = reserveUnrollMemoryForThisThread(initialMemoryThreshold)
+ if (!keepUnrolling) {
+ logWarning(s"Failed to reserve initial memory threshold of " +
+ s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.")
+ }
+
// Unroll this block safely, checking whether we have exceeded our threshold periodically
try {
while (values.hasNext && keepUnrolling) {
@@ -265,6 +278,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
Left(vector.toArray)
} else {
// We ran out of space while unrolling the values for this block
+ logUnrollFailureMessage(blockId, vector.estimateSize())
Right(vector.iterator ++ values)
}
@@ -424,7 +438,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
* Reserve additional memory for unrolling blocks used by this thread.
* Return whether the request is granted.
*/
- private[spark] def reserveUnrollMemoryForThisThread(memory: Long): Boolean = {
+ def reserveUnrollMemoryForThisThread(memory: Long): Boolean = {
accountingLock.synchronized {
val granted = freeMemory > currentUnrollMemory + memory
if (granted) {
@@ -439,7 +453,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
* Release memory used by this thread for unrolling blocks.
* If the amount is not specified, remove the current thread's allocation altogether.
*/
- private[spark] def releaseUnrollMemoryForThisThread(memory: Long = -1L): Unit = {
+ def releaseUnrollMemoryForThisThread(memory: Long = -1L): Unit = {
val threadId = Thread.currentThread().getId
accountingLock.synchronized {
if (memory < 0) {
@@ -457,16 +471,50 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
/**
* Return the amount of memory currently occupied for unrolling blocks across all threads.
*/
- private[spark] def currentUnrollMemory: Long = accountingLock.synchronized {
+ def currentUnrollMemory: Long = accountingLock.synchronized {
unrollMemoryMap.values.sum
}
/**
* Return the amount of memory currently occupied for unrolling blocks by this thread.
*/
- private[spark] def currentUnrollMemoryForThisThread: Long = accountingLock.synchronized {
+ def currentUnrollMemoryForThisThread: Long = accountingLock.synchronized {
unrollMemoryMap.getOrElse(Thread.currentThread().getId, 0L)
}
+
+ /**
+ * Return the number of threads currently unrolling blocks.
+ */
+ def numThreadsUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size }
+
+ /**
+ * Log information about current memory usage.
+ */
+ def logMemoryUsage(): Unit = {
+ val blocksMemory = currentMemory
+ val unrollMemory = currentUnrollMemory
+ val totalMemory = blocksMemory + unrollMemory
+ logInfo(
+ s"Memory use = ${Utils.bytesToString(blocksMemory)} (blocks) + " +
+ s"${Utils.bytesToString(unrollMemory)} (scratch space shared across " +
+ s"$numThreadsUnrolling thread(s)) = ${Utils.bytesToString(totalMemory)}. " +
+ s"Storage limit = ${Utils.bytesToString(maxMemory)}."
+ )
+ }
+
+ /**
+ * Log a warning for failing to unroll a block.
+ *
+ * @param blockId ID of the block we are trying to unroll.
+ * @param finalVectorSize Final size of the vector before unrolling failed.
+ */
+ def logUnrollFailureMessage(blockId: BlockId, finalVectorSize: Long): Unit = {
+ logWarning(
+ s"Not enough space to cache $blockId in memory! " +
+ s"(computed ${Utils.bytesToString(finalVectorSize)} so far)"
+ )
+ logMemoryUsage()
+ }
}
private[spark] case class ResultWithDroppedBlocks(
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 71b276b5f18e4..83170f7c5a4ab 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -19,15 +19,15 @@ package org.apache.spark.storage
import java.util.concurrent.LinkedBlockingQueue
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashSet
-import scala.collection.mutable.Queue
+import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
+import scala.util.{Failure, Success, Try}
-import org.apache.spark.{TaskContext, Logging}
-import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener, BlockTransferService}
+import org.apache.spark.{Logging, TaskContext}
+import org.apache.spark.network.BlockTransferService
+import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
+import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.serializer.Serializer
-import org.apache.spark.util.Utils
-
+import org.apache.spark.util.{CompletionIterator, Utils}
/**
* An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -40,8 +40,8 @@ import org.apache.spark.util.Utils
* using too much memory.
*
* @param context [[TaskContext]], used for metrics update
- * @param blockTransferService [[BlockTransferService]] for fetching remote blocks
- * @param blockManager [[BlockManager]] for reading local blocks
+ * @param shuffleClient [[ShuffleClient]] for fetching remote blocks
+ * @param blockManager [[BlockManager]] for reading local blocks
* @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
* For each block we also require the size (in bytes as a long field) in
* order to throttle the memory usage.
@@ -51,12 +51,12 @@ import org.apache.spark.util.Utils
private[spark]
final class ShuffleBlockFetcherIterator(
context: TaskContext,
- blockTransferService: BlockTransferService,
+ shuffleClient: ShuffleClient,
blockManager: BlockManager,
blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
serializer: Serializer,
maxBytesInFlight: Long)
- extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging {
+ extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging {
import ShuffleBlockFetcherIterator._
@@ -88,17 +88,53 @@ final class ShuffleBlockFetcherIterator(
*/
private[this] val results = new LinkedBlockingQueue[FetchResult]
- // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
- // the number of bytes in flight is limited to maxBytesInFlight
+ /**
+ * Current [[FetchResult]] being processed. We track this so we can release the current buffer
+ * in case of a runtime exception when processing the current buffer.
+ */
+ @volatile private[this] var currentResult: FetchResult = null
+
+ /**
+ * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
+ * the number of bytes in flight is limited to maxBytesInFlight.
+ */
private[this] val fetchRequests = new Queue[FetchRequest]
- // Current bytes in flight from our requests
+ /** Current bytes in flight from our requests */
private[this] var bytesInFlight = 0L
private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
+ /**
+ * Whether the iterator is still active. If isZombie is true, the callback interface will no
+ * longer place fetched blocks into [[results]].
+ */
+ @volatile private[this] var isZombie = false
+
initialize()
+ /**
+ * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
+ */
+ private[this] def cleanup() {
+ isZombie = true
+ // Release the current buffer if necessary
+ currentResult match {
+ case SuccessFetchResult(_, _, buf) => buf.release()
+ case _ =>
+ }
+
+ // Release buffers in the results queue
+ val iter = results.iterator()
+ while (iter.hasNext) {
+ val result = iter.next()
+ result match {
+ case SuccessFetchResult(_, _, buf) => buf.release()
+ case _ =>
+ }
+ }
+ }
+
private[this] def sendRequest(req: FetchRequest) {
logDebug("Sending request for %d blocks (%s) from %s".format(
req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
@@ -108,26 +144,26 @@ final class ShuffleBlockFetcherIterator(
val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
val blockIds = req.blocks.map(_._1.toString)
- blockTransferService.fetchBlocks(req.address.host, req.address.port, blockIds,
+ val address = req.address
+ shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
new BlockFetchingListener {
- override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
- results.put(new FetchResult(BlockId(blockId), sizeMap(blockId),
- () => serializer.newInstance().deserializeStream(
- blockManager.wrapForCompression(BlockId(blockId), data.inputStream())).asIterator
- ))
- shuffleMetrics.remoteBytesRead += data.size
- shuffleMetrics.remoteBlocksFetched += 1
- logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
+ override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
+ // Only add the buffer to results queue if the iterator is not zombie,
+ // i.e. cleanup() has not been called yet.
+ if (!isZombie) {
+ // Increment the ref count because we need to pass this to a different thread.
+ // This needs to be released after use.
+ buf.retain()
+ results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf))
+ shuffleMetrics.remoteBytesRead += buf.size
+ shuffleMetrics.remoteBlocksFetched += 1
+ }
+ logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
- override def onBlockFetchFailure(e: Throwable): Unit = {
+ override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
- // Note that there is a chance that some blocks have been fetched successfully, but we
- // still add them to the failed queue. This is fine because when the caller see a
- // FetchFailedException, it is going to fail the entire task anyway.
- for ((blockId, size) <- req.blocks) {
- results.put(new FetchResult(blockId, -1, null))
- }
+ results.put(new FailureFetchResult(BlockId(blockId), e))
}
}
)
@@ -138,7 +174,7 @@ final class ShuffleBlockFetcherIterator(
// smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
// nodes, rather than blocking on reading output from one node.
val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
- logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
+ logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
@@ -148,7 +184,7 @@ final class ShuffleBlockFetcherIterator(
var totalBlocks = 0
for ((address, blockInfos) <- blocksByAddress) {
totalBlocks += blockInfos.size
- if (address == blockManager.blockManagerId) {
+ if (address.executorId == blockManager.blockManagerId.executorId) {
// Filter out zero-sized blocks
localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
numBlocksToFetch += localBlocks.size
@@ -185,26 +221,34 @@ final class ShuffleBlockFetcherIterator(
remoteRequests
}
+ /**
+ * Fetch the local blocks while we are fetching remote blocks. This is ok because
+ * [[ManagedBuffer]]'s memory is allocated lazily when we create the input stream, so all we
+ * track in-memory are the ManagedBuffer references themselves.
+ */
private[this] def fetchLocalBlocks() {
- // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
- // these all at once because they will just memory-map some files, so they won't consume
- // any memory that might exceed our maxBytesInFlight
- for (id <- localBlocks) {
+ val iter = localBlocks.iterator
+ while (iter.hasNext) {
+ val blockId = iter.next()
try {
+ val buf = blockManager.getBlockData(blockId)
shuffleMetrics.localBlocksFetched += 1
- results.put(new FetchResult(
- id, 0, () => blockManager.getLocalShuffleFromDisk(id, serializer).get))
- logDebug("Got local block " + id)
+ buf.retain()
+ results.put(new SuccessFetchResult(blockId, 0, buf))
} catch {
case e: Exception =>
+ // If we see an exception, stop immediately.
logError(s"Error occurred while fetching local blocks", e)
- results.put(new FetchResult(id, -1, null))
+ results.put(new FailureFetchResult(blockId, e))
return
}
}
}
private[this] def initialize(): Unit = {
+ // Add a task completion callback (called in both success case and failure case) to cleanup.
+ context.addTaskCompletionListener(_ => cleanup())
+
// Split local and remote blocks.
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
@@ -221,26 +265,44 @@ final class ShuffleBlockFetcherIterator(
// Get Local Blocks
fetchLocalBlocks()
- logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+ logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}
override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
- override def next(): (BlockId, Option[Iterator[Any]]) = {
+ override def next(): (BlockId, Try[Iterator[Any]]) = {
numBlocksProcessed += 1
val startFetchWait = System.currentTimeMillis()
- val result = results.take()
+ currentResult = results.take()
+ val result = currentResult
val stopFetchWait = System.currentTimeMillis()
shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait)
- if (!result.failed) {
- bytesInFlight -= result.size
+
+ result match {
+ case SuccessFetchResult(_, size, _) => bytesInFlight -= size
+ case _ =>
}
// Send fetch requests up to maxBytesInFlight
while (fetchRequests.nonEmpty &&
(bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
sendRequest(fetchRequests.dequeue())
}
- (result.blockId, if (result.failed) None else Some(result.deserialize()))
+
+ val iteratorTry: Try[Iterator[Any]] = result match {
+ case FailureFetchResult(_, e) => Failure(e)
+ case SuccessFetchResult(blockId, _, buf) => {
+ val is = blockManager.wrapForCompression(blockId, buf.createInputStream())
+ val iter = serializer.newInstance().deserializeStream(is).asIterator
+ Success(CompletionIterator[Any, Iterator[Any]](iter, {
+ // Once the iterator is exhausted, release the buffer and set currentResult to null
+ // so we don't release it again in cleanup.
+ currentResult = null
+ buf.release()
+ }))
+ }
+ }
+
+ (result.blockId, iteratorTry)
}
}
@@ -254,18 +316,35 @@ object ShuffleBlockFetcherIterator {
* @param blocks Sequence of tuple, where the first element is the block id,
* and the second element is the estimated size, used to calculate bytesInFlight.
*/
- class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) {
+ case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) {
val size = blocks.map(_._2).sum
}
/**
- * Result of a fetch from a remote block. A failure is represented as size == -1.
+ * Result of a fetch from a remote block.
+ */
+ private[storage] sealed trait FetchResult {
+ val blockId: BlockId
+ }
+
+ /**
+ * Result of a fetch from a remote block successfully.
* @param blockId block id
* @param size estimated size of the block, used to calculate bytesInFlight.
* Note that this is NOT the exact bytes.
- * @param deserialize closure to return the result in the form of an Iterator.
+ * @param buf [[ManagedBuffer]] for the content.
*/
- class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) {
- def failed: Boolean = size == -1
+ private[storage] case class SuccessFetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer)
+ extends FetchResult {
+ require(buf != null)
+ require(size >= 0)
}
+
+ /**
+ * Result of a fetch from a remote block unsuccessfully.
+ * @param blockId block id
+ * @param e the failure exception
+ */
+ private[storage] case class FailureFetchResult(blockId: BlockId, e: Throwable)
+ extends FetchResult
}
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
index 1e35abaab5353..56edc4fe2e4ad 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
@@ -20,6 +20,7 @@ package org.apache.spark.storage
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.util.Utils
/**
* :: DeveloperApi ::
@@ -97,12 +98,12 @@ class StorageLevel private(
ret
}
- override def writeExternal(out: ObjectOutput) {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
out.writeByte(toInt)
out.writeByte(_replication)
}
- override def readExternal(in: ObjectInput) {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
val flags = in.readByte()
_useDisk = (flags & 8) != 0
_useMemory = (flags & 4) != 0
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
index d9066f766476e..def49e80a3605 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
@@ -19,6 +19,7 @@ package org.apache.spark.storage
import scala.collection.mutable
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.scheduler._
@@ -59,10 +60,9 @@ class StorageStatusListener extends SparkListener {
val info = taskEnd.taskInfo
val metrics = taskEnd.taskMetrics
if (info != null && metrics != null) {
- val execId = formatExecutorId(info.executorId)
val updatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]())
if (updatedBlocks.length > 0) {
- updateStorageStatus(execId, updatedBlocks)
+ updateStorageStatus(info.executorId, updatedBlocks)
}
}
}
@@ -88,13 +88,4 @@ class StorageStatusListener extends SparkListener {
}
}
- /**
- * In the local mode, there is a discrepancy between the executor ID according to the
- * task ("localhost") and that according to SparkEnv (""). In the UI, this
- * results in duplicate rows for the same executor. Thus, in this mode, we aggregate
- * these two rows and use the executor ID of "" to be consistent.
- */
- def formatExecutorId(execId: String): String = {
- if (execId == "localhost") "" else execId
- }
}
diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
index 6908a59a79e60..af873034215a9 100644
--- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
@@ -148,6 +148,7 @@ private[spark] class TachyonBlockManager(
logError("Exception while deleting tachyon spark dir: " + tachyonDir, e)
}
}
+ client.close()
}
})
}
diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala
index 932b5616043b4..233d1e2b7c616 100644
--- a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala
@@ -20,6 +20,7 @@ package org.apache.spark.storage
import java.io.IOException
import java.nio.ByteBuffer
+import com.google.common.io.ByteStreams
import tachyon.client.{ReadType, WriteType}
import org.apache.spark.Logging
@@ -105,25 +106,19 @@ private[spark] class TachyonStore(
return None
}
val is = file.getInStream(ReadType.CACHE)
- var buffer: ByteBuffer = null
+ assert (is != null)
try {
- if (is != null) {
- val size = file.length
- val bs = new Array[Byte](size.asInstanceOf[Int])
- val fetchSize = is.read(bs, 0, size.asInstanceOf[Int])
- buffer = ByteBuffer.wrap(bs)
- if (fetchSize != size) {
- logWarning(s"Failed to fetch the block $blockId from Tachyon: Size $size " +
- s"is not equal to fetched size $fetchSize")
- return None
- }
- }
+ val size = file.length
+ val bs = new Array[Byte](size.asInstanceOf[Int])
+ ByteStreams.readFully(is, bs)
+ Some(ByteBuffer.wrap(bs))
} catch {
case ioe: IOException =>
logWarning(s"Failed to fetch the block $blockId from Tachyon", ioe)
- return None
+ None
+ } finally {
+ is.close()
}
- Some(buffer)
}
override def contains(blockId: BlockId): Boolean = {
diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala
new file mode 100644
index 0000000000000..27ba9e18237b5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+import java.util.{Timer, TimerTask}
+
+import org.apache.spark._
+
+/**
+ * ConsoleProgressBar shows the progress of stages in the next line of the console. It poll the
+ * status of active stages from `sc.statusTracker` periodically, the progress bar will be showed
+ * up after the stage has ran at least 500ms. If multiple stages run in the same time, the status
+ * of them will be combined together, showed in one line.
+ */
+private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging {
+
+ // Carrige return
+ val CR = '\r'
+ // Update period of progress bar, in milliseconds
+ val UPDATE_PERIOD = 200L
+ // Delay to show up a progress bar, in milliseconds
+ val FIRST_DELAY = 500L
+
+ // The width of terminal
+ val TerminalWidth = if (!sys.env.getOrElse("COLUMNS", "").isEmpty) {
+ sys.env.get("COLUMNS").get.toInt
+ } else {
+ 80
+ }
+
+ var lastFinishTime = 0L
+ var lastUpdateTime = 0L
+ var lastProgressBar = ""
+
+ // Schedule a refresh thread to run periodically
+ private val timer = new Timer("refresh progress", true)
+ timer.schedule(new TimerTask{
+ override def run() {
+ refresh()
+ }
+ }, FIRST_DELAY, UPDATE_PERIOD)
+
+ /**
+ * Try to refresh the progress bar in every cycle
+ */
+ private def refresh(): Unit = synchronized {
+ val now = System.currentTimeMillis()
+ if (now - lastFinishTime < FIRST_DELAY) {
+ return
+ }
+ val stageIds = sc.statusTracker.getActiveStageIds()
+ val stages = stageIds.map(sc.statusTracker.getStageInfo).flatten.filter(_.numTasks() > 1)
+ .filter(now - _.submissionTime() > FIRST_DELAY).sortBy(_.stageId())
+ if (stages.size > 0) {
+ show(now, stages.take(3)) // display at most 3 stages in same time
+ }
+ }
+
+ /**
+ * Show progress bar in console. The progress bar is displayed in the next line
+ * after your last output, keeps overwriting itself to hold in one line. The logging will follow
+ * the progress bar, then progress bar will be showed in next line without overwrite logs.
+ */
+ private def show(now: Long, stages: Seq[SparkStageInfo]) {
+ val width = TerminalWidth / stages.size
+ val bar = stages.map { s =>
+ val total = s.numTasks()
+ val header = s"[Stage ${s.stageId()}:"
+ val tailer = s"(${s.numCompletedTasks()} + ${s.numActiveTasks()}) / $total]"
+ val w = width - header.size - tailer.size
+ val bar = if (w > 0) {
+ val percent = w * s.numCompletedTasks() / total
+ (0 until w).map { i =>
+ if (i < percent) "=" else if (i == percent) ">" else " "
+ }.mkString("")
+ } else {
+ ""
+ }
+ header + bar + tailer
+ }.mkString("")
+
+ // only refresh if it's changed of after 1 minute (or the ssh connection will be closed
+ // after idle some time)
+ if (bar != lastProgressBar || now - lastUpdateTime > 60 * 1000L) {
+ System.err.print(CR + bar)
+ lastUpdateTime = now
+ }
+ lastProgressBar = bar
+ }
+
+ /**
+ * Clear the progress bar if showed.
+ */
+ private def clear() {
+ if (!lastProgressBar.isEmpty) {
+ System.err.printf(CR + " " * TerminalWidth + CR)
+ lastProgressBar = ""
+ }
+ }
+
+ /**
+ * Mark all the stages as finished, clear the progress bar if showed, then the progress will not
+ * interweave with output of jobs.
+ */
+ def finishAll(): Unit = synchronized {
+ clear()
+ lastFinishTime = System.currentTimeMillis()
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
index cccd59d122a92..176907dffa46a 100644
--- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
@@ -21,60 +21,46 @@ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext}
import org.apache.spark.scheduler._
import org.apache.spark.storage.StorageStatusListener
import org.apache.spark.ui.JettyUtils._
-import org.apache.spark.ui.env.EnvironmentTab
-import org.apache.spark.ui.exec.ExecutorsTab
-import org.apache.spark.ui.jobs.JobProgressTab
-import org.apache.spark.ui.storage.StorageTab
+import org.apache.spark.ui.env.{EnvironmentListener, EnvironmentTab}
+import org.apache.spark.ui.exec.{ExecutorsListener, ExecutorsTab}
+import org.apache.spark.ui.jobs.{JobsTab, JobProgressListener, StagesTab}
+import org.apache.spark.ui.storage.{StorageListener, StorageTab}
/**
* Top level user interface for a Spark application.
*/
-private[spark] class SparkUI(
- val sc: SparkContext,
+private[spark] class SparkUI private (
+ val sc: Option[SparkContext],
val conf: SparkConf,
val securityManager: SecurityManager,
- val listenerBus: SparkListenerBus,
+ val environmentListener: EnvironmentListener,
+ val storageStatusListener: StorageStatusListener,
+ val executorsListener: ExecutorsListener,
+ val jobProgressListener: JobProgressListener,
+ val storageListener: StorageListener,
var appName: String,
- val basePath: String = "")
+ val basePath: String)
extends WebUI(securityManager, SparkUI.getUIPort(conf), conf, basePath, "SparkUI")
with Logging {
- def this(sc: SparkContext) = this(sc, sc.conf, sc.env.securityManager, sc.listenerBus, sc.appName)
- def this(conf: SparkConf, listenerBus: SparkListenerBus, appName: String, basePath: String) =
- this(null, conf, new SecurityManager(conf), listenerBus, appName, basePath)
-
- def this(
- conf: SparkConf,
- securityManager: SecurityManager,
- listenerBus: SparkListenerBus,
- appName: String,
- basePath: String) =
- this(null, conf, securityManager, listenerBus, appName, basePath)
-
- // If SparkContext is not provided, assume the associated application is not live
- val live = sc != null
-
- // Maintain executor storage status through Spark events
- val storageStatusListener = new StorageStatusListener
-
- initialize()
+ val killEnabled = sc.map(_.conf.getBoolean("spark.ui.killEnabled", true)).getOrElse(false)
/** Initialize all components of the server. */
def initialize() {
- listenerBus.addListener(storageStatusListener)
- val jobProgressTab = new JobProgressTab(this)
- attachTab(jobProgressTab)
+ attachTab(new JobsTab(this))
+ val stagesTab = new StagesTab(this)
+ attachTab(stagesTab)
attachTab(new StorageTab(this))
attachTab(new EnvironmentTab(this))
attachTab(new ExecutorsTab(this))
attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static"))
- attachHandler(createRedirectHandler("/", "/stages", basePath = basePath))
+ attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath))
attachHandler(
- createRedirectHandler("/stages/stage/kill", "/stages", jobProgressTab.handleKillRequest))
- if (live) {
- sc.env.metricsSystem.getServletHandlers.foreach(attachHandler)
- }
+ createRedirectHandler("/stages/stage/kill", "/stages", stagesTab.handleKillRequest))
+ // If the UI is live, then serve
+ sc.foreach { _.env.metricsSystem.getServletHandlers.foreach(attachHandler) }
}
+ initialize()
def getAppName = appName
@@ -83,11 +69,6 @@ private[spark] class SparkUI(
appName = name
}
- /** Register the given listener with the listener bus. */
- def registerListener(listener: SparkListener) {
- listenerBus.addListener(listener)
- }
-
/** Stop the server behind this web interface. Only valid after bind(). */
override def stop() {
super.stop()
@@ -116,4 +97,60 @@ private[spark] object SparkUI {
def getUIPort(conf: SparkConf): Int = {
conf.getInt("spark.ui.port", SparkUI.DEFAULT_PORT)
}
+
+ def createLiveUI(
+ sc: SparkContext,
+ conf: SparkConf,
+ listenerBus: SparkListenerBus,
+ jobProgressListener: JobProgressListener,
+ securityManager: SecurityManager,
+ appName: String): SparkUI = {
+ create(Some(sc), conf, listenerBus, securityManager, appName,
+ jobProgressListener = Some(jobProgressListener))
+ }
+
+ def createHistoryUI(
+ conf: SparkConf,
+ listenerBus: SparkListenerBus,
+ securityManager: SecurityManager,
+ appName: String,
+ basePath: String): SparkUI = {
+ create(None, conf, listenerBus, securityManager, appName, basePath)
+ }
+
+ /**
+ * Create a new Spark UI.
+ *
+ * @param sc optional SparkContext; this can be None when reconstituting a UI from event logs.
+ * @param jobProgressListener if supplied, this JobProgressListener will be used; otherwise, the
+ * web UI will create and register its own JobProgressListener.
+ */
+ private def create(
+ sc: Option[SparkContext],
+ conf: SparkConf,
+ listenerBus: SparkListenerBus,
+ securityManager: SecurityManager,
+ appName: String,
+ basePath: String = "",
+ jobProgressListener: Option[JobProgressListener] = None): SparkUI = {
+
+ val _jobProgressListener: JobProgressListener = jobProgressListener.getOrElse {
+ val listener = new JobProgressListener(conf)
+ listenerBus.addListener(listener)
+ listener
+ }
+
+ val environmentListener = new EnvironmentListener
+ val storageStatusListener = new StorageStatusListener
+ val executorsListener = new ExecutorsListener(storageStatusListener)
+ val storageListener = new StorageListener(storageStatusListener)
+
+ listenerBus.addListener(environmentListener)
+ listenerBus.addListener(storageStatusListener)
+ listenerBus.addListener(executorsListener)
+ listenerBus.addListener(storageListener)
+
+ new SparkUI(sc, conf, securityManager, environmentListener, storageStatusListener,
+ executorsListener, _jobProgressListener, storageListener, appName, basePath)
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
index 9ced9b8107ebf..6f446c5a95a0a 100644
--- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
+++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
@@ -24,11 +24,28 @@ private[spark] object ToolTips {
scheduler delay is large, consider decreasing the size of tasks or decreasing the size
of task results."""
+ val TASK_DESERIALIZATION_TIME =
+ """Time spent deserializating the task closure on the executor."""
+
val INPUT = "Bytes read from Hadoop or from Spark storage."
+ val OUTPUT = "Bytes written to Hadoop."
+
val SHUFFLE_WRITE = "Bytes written to disk in order to be read by a shuffle in a future stage."
val SHUFFLE_READ =
"""Bytes read from remote executors. Typically less than shuffle write bytes
because this does not include shuffle data read locally."""
+
+ val GETTING_RESULT_TIME =
+ """Time that the driver spends fetching task results from workers. If this is large, consider
+ decreasing the amount of data returned from each task."""
+
+ val RESULT_SERIALIZATION_TIME =
+ """Time spent serializing the task result on the executor before sending it back to the
+ driver."""
+
+ val GC_TIME =
+ """Time that the executor spent paused for Java garbage collection while the task was
+ running."""
}
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index f0006b42aee4f..09079bbd43f6f 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -20,12 +20,14 @@ package org.apache.spark.ui
import java.text.SimpleDateFormat
import java.util.{Locale, Date}
-import scala.xml.Node
+import scala.xml.{Node, Text}
+
import org.apache.spark.Logging
/** Utility functions for generating XML pages with spark content. */
private[spark] object UIUtils extends Logging {
- val TABLE_CLASS = "table table-bordered table-striped table-condensed sortable"
+ val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed sortable"
+ val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped"
// SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use.
private val dateFormat = new ThreadLocal[SimpleDateFormat]() {
@@ -159,6 +161,8 @@ private[spark] object UIUtils extends Logging {
+
+
}
/** Returns a spark page with correctly formatted headers */
@@ -166,14 +170,19 @@ private[spark] object UIUtils extends Logging {
title: String,
content: => Seq[Node],
activeTab: SparkUITab,
- refreshInterval: Option[Int] = None): Seq[Node] = {
+ refreshInterval: Option[Int] = None,
+ helpText: Option[String] = None): Seq[Node] = {
val appName = activeTab.appName
+ val shortAppName = if (appName.length < 36) appName else appName.take(32) + "..."
val header = activeTab.headerTabs.map { tab =>
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
index 5d88ca403a674..9be65a4a39a09 100644
--- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
@@ -82,7 +82,7 @@ private[spark] abstract class WebUI(
}
/** Detach a handler from this UI. */
- def detachHandler(handler: ServletContextHandler) {
+ protected def detachHandler(handler: ServletContextHandler) {
handlers -= handler
serverInfo.foreach { info =>
info.rootHandler.removeHandler(handler)
diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala
index 0d158fbe638d3..f62260c6f6e1d 100644
--- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala
@@ -22,10 +22,8 @@ import org.apache.spark.scheduler._
import org.apache.spark.ui._
private[ui] class EnvironmentTab(parent: SparkUI) extends SparkUITab(parent, "environment") {
- val listener = new EnvironmentListener
-
+ val listener = parent.environmentListener
attachPage(new EnvironmentPage(this))
- parent.registerListener(listener)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
new file mode 100644
index 0000000000000..c82730f524eb7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.exec
+
+import java.net.URLDecoder
+import javax.servlet.http.HttpServletRequest
+
+import scala.util.Try
+import scala.xml.{Text, Node}
+
+import org.apache.spark.ui.{UIUtils, WebUIPage}
+
+private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage("threadDump") {
+
+ private val sc = parent.sc
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ val executorId = Option(request.getParameter("executorId")).map {
+ executorId =>
+ // Due to YARN-2844, "" in the url will be encoded to "%25253Cdriver%25253E" when
+ // running in yarn-cluster mode. `request.getParameter("executorId")` will return
+ // "%253Cdriver%253E". Therefore we need to decode it until we get the real id.
+ var id = executorId
+ var decodedId = URLDecoder.decode(id, "UTF-8")
+ while (id != decodedId) {
+ id = decodedId
+ decodedId = URLDecoder.decode(id, "UTF-8")
+ }
+ id
+ }.getOrElse {
+ return Text(s"Missing executorId parameter")
+ }
+ val time = System.currentTimeMillis()
+ val maybeThreadDump = sc.get.getExecutorThreadDump(executorId)
+
+ val content = maybeThreadDump.map { threadDump =>
+ val dumpRows = threadDump.map { thread =>
+
+ } else {
+ Seq.empty
+ }
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
index 61eb111cd9100..dd1c2b78c4094 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
@@ -26,10 +26,15 @@ import org.apache.spark.storage.StorageStatusListener
import org.apache.spark.ui.{SparkUI, SparkUITab}
private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") {
- val listener = new ExecutorsListener(parent.storageStatusListener)
+ val listener = parent.executorsListener
+ val sc = parent.sc
+ val threadDumpEnabled =
+ sc.isDefined && parent.conf.getBoolean("spark.ui.threadDumpsEnabled", true)
- attachPage(new ExecutorsPage(this))
- parent.registerListener(listener)
+ attachPage(new ExecutorsPage(this, threadDumpEnabled))
+ if (threadDumpEnabled) {
+ attachPage(new ExecutorThreadDumpPage(this))
+ }
}
/**
@@ -43,20 +48,21 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp
val executorToTasksFailed = HashMap[String, Int]()
val executorToDuration = HashMap[String, Long]()
val executorToInputBytes = HashMap[String, Long]()
+ val executorToOutputBytes = HashMap[String, Long]()
val executorToShuffleRead = HashMap[String, Long]()
val executorToShuffleWrite = HashMap[String, Long]()
def storageStatusList = storageStatusListener.storageStatusList
override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized {
- val eid = formatExecutorId(taskStart.taskInfo.executorId)
+ val eid = taskStart.taskInfo.executorId
executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 0) + 1
}
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
val info = taskEnd.taskInfo
if (info != null) {
- val eid = formatExecutorId(info.executorId)
+ val eid = info.executorId
executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 1) - 1
executorToDuration(eid) = executorToDuration.getOrElse(eid, 0L) + info.duration
taskEnd.reason match {
@@ -73,6 +79,10 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp
executorToInputBytes(eid) =
executorToInputBytes.getOrElse(eid, 0L) + inputMetrics.bytesRead
}
+ metrics.outputMetrics.foreach { outputMetrics =>
+ executorToOutputBytes(eid) =
+ executorToOutputBytes.getOrElse(eid, 0L) + outputMetrics.bytesWritten
+ }
metrics.shuffleReadMetrics.foreach { shuffleRead =>
executorToShuffleRead(eid) =
executorToShuffleRead.getOrElse(eid, 0L) + shuffleRead.remoteBytesRead
@@ -85,6 +95,4 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp
}
}
- // This addresses executor ID inconsistencies in the local mode
- private def formatExecutorId(execId: String) = storageStatusListener.formatExecutorId(execId)
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
new file mode 100644
index 0000000000000..ea2d187a0e8e4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+import scala.xml.{Node, NodeSeq}
+
+import javax.servlet.http.HttpServletRequest
+
+import org.apache.spark.JobExecutionStatus
+import org.apache.spark.ui.{WebUIPage, UIUtils}
+import org.apache.spark.ui.jobs.UIData.JobUIData
+
+/** Page showing list of all ongoing and recently finished jobs */
+private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
+ private val startTime: Option[Long] = parent.sc.map(_.startTime)
+ private val listener = parent.listener
+
+ private def jobsTable(jobs: Seq[JobUIData]): Seq[Node] = {
+ val someJobHasJobGroup = jobs.exists(_.jobGroup.isDefined)
+
+ val columns: Seq[Node] = {
+
{if (someJobHasJobGroup) "Job Id (Job Group)" else "Job Id"}
+
Description
+
Submitted
+
Duration
+
Stages: Succeeded/Total
+
Tasks (for all stages): Succeeded/Total
+ }
+
+ def makeRow(job: JobUIData): Seq[Node] = {
+ val lastStageInfo = listener.stageIdToInfo.get(job.stageIds.max)
+ val lastStageData = lastStageInfo.flatMap { s =>
+ listener.stageIdToData.get((s.stageId, s.attemptId))
+ }
+ val isComplete = job.status == JobExecutionStatus.SUCCEEDED
+ val lastStageName = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)")
+ val lastStageDescription = lastStageData.flatMap(_.description).getOrElse("")
+ val duration: Option[Long] = {
+ job.startTime.map { start =>
+ val end = job.endTime.getOrElse(System.currentTimeMillis())
+ end - start
+ }
+ }
+ val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown")
+ val formattedSubmissionTime = job.startTime.map(UIUtils.formatDate).getOrElse("Unknown")
+ val detailUrl =
+ "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), job.jobId)
+
++ failedJobsTable
+
+ val helpText = """A job is triggered by a action, like "count()" or "saveAsTextFile()".""" +
+ " Click on a job's title to see information about the stages of tasks associated with" +
+ " the job."
+
+ UIUtils.headerSparkPage("Spark Jobs", content, parent, helpText = Some(helpText))
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala
new file mode 100644
index 0000000000000..b0f8ca2ab0d3f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.xml.{Node, NodeSeq}
+
+import org.apache.spark.scheduler.Schedulable
+import org.apache.spark.ui.{WebUIPage, UIUtils}
+
+/** Page showing list of all ongoing and recently finished stages and pools */
+private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") {
+ private val sc = parent.sc
+ private val listener = parent.listener
+ private def isFairScheduler = parent.isFairScheduler
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ listener.synchronized {
+ val activeStages = listener.activeStages.values.toSeq
+ val completedStages = listener.completedStages.reverse.toSeq
+ val numCompletedStages = listener.numCompletedStages
+ val failedStages = listener.failedStages.reverse.toSeq
+ val numFailedStages = listener.numFailedStages
+ val now = System.currentTimeMillis
+
+ val activeStagesTable =
+ new StageTableBase(activeStages.sortBy(_.submissionTime).reverse,
+ parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler,
+ killEnabled = parent.killEnabled)
+ val completedStagesTable =
+ new StageTableBase(completedStages.sortBy(_.submissionTime).reverse, parent.basePath,
+ parent.listener, isFairScheduler = parent.isFairScheduler, killEnabled = false)
+ val failedStagesTable =
+ new FailedStageTable(failedStages.sortBy(_.submissionTime).reverse, parent.basePath,
+ parent.listener, isFairScheduler = parent.isFairScheduler)
+
+ // For now, pool information is only accessible in live UIs
+ val pools = sc.map(_.getAllPools).getOrElse(Seq.empty[Schedulable])
+ val poolTable = new PoolTable(pools, parent)
+
+ val summary: NodeSeq =
+
+
+ {if (sc.isDefined) {
+ // Total duration is not meaningful unless the UI is live
+
+ Total Duration:
+ {UIUtils.formatDuration(now - sc.get.startTime)}
+
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
new file mode 100644
index 0000000000000..77d36209c6048
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+import scala.collection.mutable
+import scala.xml.{NodeSeq, Node}
+
+import javax.servlet.http.HttpServletRequest
+
+import org.apache.spark.JobExecutionStatus
+import org.apache.spark.scheduler.StageInfo
+import org.apache.spark.ui.{UIUtils, WebUIPage}
+
+/** Page showing statistics and stage list for a given job */
+private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") {
+ private val listener = parent.listener
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ listener.synchronized {
+ val jobId = request.getParameter("id").toInt
+ val jobDataOption = listener.jobIdToData.get(jobId)
+ if (jobDataOption.isEmpty) {
+ val content =
+
+
No information to display for job {jobId}
+
+ return UIUtils.headerSparkPage(
+ s"Details for Job $jobId", content, parent)
+ }
+ val jobData = jobDataOption.get
+ val isComplete = jobData.status != JobExecutionStatus.RUNNING
+ val stages = jobData.stageIds.map { stageId =>
+ // This could be empty if the JobProgressListener hasn't received information about the
+ // stage or if the stage information has been garbage collected
+ listener.stageIdToInfo.getOrElse(stageId,
+ new StageInfo(stageId, 0, "Unknown", 0, Seq.empty, "Unknown"))
+ }
+
+ val activeStages = mutable.Buffer[StageInfo]()
+ val completedStages = mutable.Buffer[StageInfo]()
+ // If the job is completed, then any pending stages are displayed as "skipped":
+ val pendingOrSkippedStages = mutable.Buffer[StageInfo]()
+ val failedStages = mutable.Buffer[StageInfo]()
+ for (stage <- stages) {
+ if (stage.submissionTime.isEmpty) {
+ pendingOrSkippedStages += stage
+ } else if (stage.completionTime.isDefined) {
+ if (stage.failureReason.isDefined) {
+ failedStages += stage
+ } else {
+ completedStages += stage
+ }
+ } else {
+ activeStages += stage
+ }
+ }
+
+ val activeStagesTable =
+ new StageTableBase(activeStages.sortBy(_.submissionTime).reverse,
+ parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler,
+ killEnabled = parent.killEnabled)
+ val pendingOrSkippedStagesTable =
+ new StageTableBase(pendingOrSkippedStages.sortBy(_.stageId).reverse,
+ parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler,
+ killEnabled = false)
+ val completedStagesTable =
+ new StageTableBase(completedStages.sortBy(_.submissionTime).reverse, parent.basePath,
+ parent.listener, isFairScheduler = parent.isFairScheduler, killEnabled = false)
+ val failedStagesTable =
+ new FailedStageTable(failedStages.sortBy(_.submissionTime).reverse, parent.basePath,
+ parent.listener, isFairScheduler = parent.isFairScheduler)
+
+ val shouldShowActiveStages = activeStages.nonEmpty
+ val shouldShowPendingStages = !isComplete && pendingOrSkippedStages.nonEmpty
+ val shouldShowCompletedStages = completedStages.nonEmpty
+ val shouldShowSkippedStages = isComplete && pendingOrSkippedStages.nonEmpty
+ val shouldShowFailedStages = failedStages.nonEmpty
+
+ val summary: NodeSeq =
+
++
+ failedStagesTable.toNodeSeq
+ }
+ UIUtils.headerSparkPage(s"Details for Job $jobId", content, parent)
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index eaeb861f59e5a..72935beb3a34a 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ui.jobs
-import scala.collection.mutable.{HashMap, ListBuffer}
+import scala.collection.mutable.{HashMap, HashSet, ListBuffer}
import org.apache.spark._
import org.apache.spark.annotation.DeveloperApi
@@ -40,29 +40,182 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
import JobProgressListener._
- // How many stages to remember
- val retainedStages = conf.getInt("spark.ui.retainedStages", DEFAULT_RETAINED_STAGES)
+ // Define a handful of type aliases so that data structures' types can serve as documentation.
+ // These type aliases are public because they're used in the types of public fields:
- // Map from stageId to StageInfo
- val activeStages = new HashMap[Int, StageInfo]
+ type JobId = Int
+ type StageId = Int
+ type StageAttemptId = Int
+ type PoolName = String
+ type ExecutorId = String
- // Map from (stageId, attemptId) to StageUIData
- val stageIdToData = new HashMap[(Int, Int), StageUIData]
+ // Jobs:
+ val activeJobs = new HashMap[JobId, JobUIData]
+ val completedJobs = ListBuffer[JobUIData]()
+ val failedJobs = ListBuffer[JobUIData]()
+ val jobIdToData = new HashMap[JobId, JobUIData]
+ // Stages:
+ val activeStages = new HashMap[StageId, StageInfo]
val completedStages = ListBuffer[StageInfo]()
+ val skippedStages = ListBuffer[StageInfo]()
val failedStages = ListBuffer[StageInfo]()
+ val stageIdToData = new HashMap[(StageId, StageAttemptId), StageUIData]
+ val stageIdToInfo = new HashMap[StageId, StageInfo]
+ val stageIdToActiveJobIds = new HashMap[StageId, HashSet[JobId]]
+ val poolToActiveStages = HashMap[PoolName, HashMap[StageId, StageInfo]]()
+ // Total of completed and failed stages that have ever been run. These may be greater than
+ // `completedStages.size` and `failedStages.size` if we have run more stages or jobs than
+ // JobProgressListener's retention limits.
+ var numCompletedStages = 0
+ var numFailedStages = 0
+
+ // Misc:
+ val executorIdToBlockManagerId = HashMap[ExecutorId, BlockManagerId]()
+ def blockManagerIds = executorIdToBlockManagerId.values.toSeq
- // Map from pool name to a hash map (map from stage id to StageInfo).
- val poolToActiveStages = HashMap[String, HashMap[Int, StageInfo]]()
+ var schedulingMode: Option[SchedulingMode] = None
- val executorIdToBlockManagerId = HashMap[String, BlockManagerId]()
+ // To limit the total memory usage of JobProgressListener, we only track information for a fixed
+ // number of non-active jobs and stages (there is no limit for active jobs and stages):
- var schedulingMode: Option[SchedulingMode] = None
+ val retainedStages = conf.getInt("spark.ui.retainedStages", DEFAULT_RETAINED_STAGES)
+ val retainedJobs = conf.getInt("spark.ui.retainedJobs", DEFAULT_RETAINED_JOBS)
+
+ // We can test for memory leaks by ensuring that collections that track non-active jobs and
+ // stages do not grow without bound and that collections for active jobs/stages eventually become
+ // empty once Spark is idle. Let's partition our collections into ones that should be empty
+ // once Spark is idle and ones that should have a hard- or soft-limited sizes.
+ // These methods are used by unit tests, but they're defined here so that people don't forget to
+ // update the tests when adding new collections. Some collections have multiple levels of
+ // nesting, etc, so this lets us customize our notion of "size" for each structure:
+
+ // These collections should all be empty once Spark is idle (no active stages / jobs):
+ private[spark] def getSizesOfActiveStateTrackingCollections: Map[String, Int] = {
+ Map(
+ "activeStages" -> activeStages.size,
+ "activeJobs" -> activeJobs.size,
+ "poolToActiveStages" -> poolToActiveStages.values.map(_.size).sum,
+ "stageIdToActiveJobIds" -> stageIdToActiveJobIds.values.map(_.size).sum
+ )
+ }
- def blockManagerIds = executorIdToBlockManagerId.values.toSeq
+ // These collections should stop growing once we have run at least `spark.ui.retainedStages`
+ // stages and `spark.ui.retainedJobs` jobs:
+ private[spark] def getSizesOfHardSizeLimitedCollections: Map[String, Int] = {
+ Map(
+ "completedJobs" -> completedJobs.size,
+ "failedJobs" -> failedJobs.size,
+ "completedStages" -> completedStages.size,
+ "skippedStages" -> skippedStages.size,
+ "failedStages" -> failedStages.size
+ )
+ }
+
+ // These collections may grow arbitrarily, but once Spark becomes idle they should shrink back to
+ // some bound based on the `spark.ui.retainedStages` and `spark.ui.retainedJobs` settings:
+ private[spark] def getSizesOfSoftSizeLimitedCollections: Map[String, Int] = {
+ Map(
+ "jobIdToData" -> jobIdToData.size,
+ "stageIdToData" -> stageIdToData.size,
+ "stageIdToStageInfo" -> stageIdToInfo.size
+ )
+ }
+
+ /** If stages is too large, remove and garbage collect old stages */
+ private def trimStagesIfNecessary(stages: ListBuffer[StageInfo]) = synchronized {
+ if (stages.size > retainedStages) {
+ val toRemove = math.max(retainedStages / 10, 1)
+ stages.take(toRemove).foreach { s =>
+ stageIdToData.remove((s.stageId, s.attemptId))
+ stageIdToInfo.remove(s.stageId)
+ }
+ stages.trimStart(toRemove)
+ }
+ }
+
+ /** If jobs is too large, remove and garbage collect old jobs */
+ private def trimJobsIfNecessary(jobs: ListBuffer[JobUIData]) = synchronized {
+ if (jobs.size > retainedJobs) {
+ val toRemove = math.max(retainedJobs / 10, 1)
+ jobs.take(toRemove).foreach { job =>
+ jobIdToData.remove(job.jobId)
+ }
+ jobs.trimStart(toRemove)
+ }
+ }
+
+ override def onJobStart(jobStart: SparkListenerJobStart) = synchronized {
+ val jobGroup = for (
+ props <- Option(jobStart.properties);
+ group <- Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID))
+ ) yield group
+ val jobData: JobUIData =
+ new JobUIData(
+ jobId = jobStart.jobId,
+ startTime = Some(System.currentTimeMillis),
+ endTime = None,
+ stageIds = jobStart.stageIds,
+ jobGroup = jobGroup,
+ status = JobExecutionStatus.RUNNING)
+ // Compute (a potential underestimate of) the number of tasks that will be run by this job.
+ // This may be an underestimate because the job start event references all of the result
+ // stages's transitive stage dependencies, but some of these stages might be skipped if their
+ // output is available from earlier runs.
+ // See https://github.com/apache/spark/pull/3009 for a more extensive discussion.
+ jobData.numTasks = {
+ val allStages = jobStart.stageInfos
+ val missingStages = allStages.filter(_.completionTime.isEmpty)
+ missingStages.map(_.numTasks).sum
+ }
+ jobIdToData(jobStart.jobId) = jobData
+ activeJobs(jobStart.jobId) = jobData
+ for (stageId <- jobStart.stageIds) {
+ stageIdToActiveJobIds.getOrElseUpdate(stageId, new HashSet[StageId]).add(jobStart.jobId)
+ }
+ // If there's no information for a stage, store the StageInfo received from the scheduler
+ // so that we can display stage descriptions for pending stages:
+ for (stageInfo <- jobStart.stageInfos) {
+ stageIdToInfo.getOrElseUpdate(stageInfo.stageId, stageInfo)
+ stageIdToData.getOrElseUpdate((stageInfo.stageId, stageInfo.attemptId), new StageUIData)
+ }
+ }
+
+ override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized {
+ val jobData = activeJobs.remove(jobEnd.jobId).getOrElse {
+ logWarning(s"Job completed for unknown job ${jobEnd.jobId}")
+ new JobUIData(jobId = jobEnd.jobId)
+ }
+ jobData.endTime = Some(System.currentTimeMillis())
+ jobEnd.jobResult match {
+ case JobSucceeded =>
+ completedJobs += jobData
+ trimJobsIfNecessary(completedJobs)
+ jobData.status = JobExecutionStatus.SUCCEEDED
+ case JobFailed(exception) =>
+ failedJobs += jobData
+ trimJobsIfNecessary(failedJobs)
+ jobData.status = JobExecutionStatus.FAILED
+ }
+ for (stageId <- jobData.stageIds) {
+ stageIdToActiveJobIds.get(stageId).foreach { jobsUsingStage =>
+ jobsUsingStage.remove(jobEnd.jobId)
+ stageIdToInfo.get(stageId).foreach { stageInfo =>
+ if (stageInfo.submissionTime.isEmpty) {
+ // if this stage is pending, it won't complete, so mark it as "skipped":
+ skippedStages += stageInfo
+ trimStagesIfNecessary(skippedStages)
+ jobData.numSkippedStages += 1
+ jobData.numSkippedTasks += stageInfo.numTasks
+ }
+ }
+ }
+ }
+ }
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized {
val stage = stageCompleted.stageInfo
+ stageIdToInfo(stage.stageId) = stage
val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptId), {
logWarning("Stage completed for unknown stage " + stage.stageId)
new StageUIData
@@ -78,19 +231,25 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
activeStages.remove(stage.stageId)
if (stage.failureReason.isEmpty) {
completedStages += stage
- trimIfNecessary(completedStages)
+ numCompletedStages += 1
+ trimStagesIfNecessary(completedStages)
} else {
failedStages += stage
- trimIfNecessary(failedStages)
+ numFailedStages += 1
+ trimStagesIfNecessary(failedStages)
}
- }
- /** If stages is too large, remove and garbage collect old stages */
- private def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized {
- if (stages.size > retainedStages) {
- val toRemove = math.max(retainedStages / 10, 1)
- stages.take(toRemove).foreach { s => stageIdToData.remove((s.stageId, s.attemptId)) }
- stages.trimStart(toRemove)
+ for (
+ activeJobsDependentOnStage <- stageIdToActiveJobIds.get(stage.stageId);
+ jobId <- activeJobsDependentOnStage;
+ jobData <- jobIdToData.get(jobId)
+ ) {
+ jobData.numActiveStages -= 1
+ if (stage.failureReason.isEmpty) {
+ jobData.completedStageIndices.add(stage.stageId)
+ } else {
+ jobData.numFailedStages += 1
+ }
}
}
@@ -103,6 +262,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
p => p.getProperty("spark.scheduler.pool", DEFAULT_POOL_NAME)
}.getOrElse(DEFAULT_POOL_NAME)
+ stageIdToInfo(stage.stageId) = stage
val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptId), new StageUIData)
stageData.schedulingPool = poolName
@@ -112,6 +272,14 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashMap[Int, StageInfo])
stages(stage.stageId) = stage
+
+ for (
+ activeJobsDependentOnStage <- stageIdToActiveJobIds.get(stage.stageId);
+ jobId <- activeJobsDependentOnStage;
+ jobData <- jobIdToData.get(jobId)
+ ) {
+ jobData.numActiveStages += 1
+ }
}
override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized {
@@ -124,6 +292,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
stageData.numActiveTasks += 1
stageData.taskData.put(taskInfo.taskId, new TaskUIData(taskInfo))
}
+ for (
+ activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskStart.stageId);
+ jobId <- activeJobsDependentOnStage;
+ jobData <- jobIdToData.get(jobId)
+ ) {
+ jobData.numActiveTasks += 1
+ }
}
override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) {
@@ -181,6 +356,20 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
taskData.taskInfo = info
taskData.taskMetrics = metrics
taskData.errorMessage = errorMessage
+
+ for (
+ activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskEnd.stageId);
+ jobId <- activeJobsDependentOnStage;
+ jobData <- jobIdToData.get(jobId)
+ ) {
+ jobData.numActiveTasks -= 1
+ taskEnd.reason match {
+ case Success =>
+ jobData.numCompletedTasks += 1
+ case _ =>
+ jobData.numFailedTasks += 1
+ }
+ }
}
}
@@ -214,6 +403,12 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
stageData.inputBytes += inputBytesDelta
execSummary.inputBytes += inputBytesDelta
+ val outputBytesDelta =
+ (taskMetrics.outputMetrics.map(_.bytesWritten).getOrElse(0L)
+ - oldMetrics.flatMap(_.outputMetrics).map(_.bytesWritten).getOrElse(0L))
+ stageData.outputBytes += outputBytesDelta
+ execSummary.outputBytes += outputBytesDelta
+
val diskSpillDelta =
taskMetrics.diskBytesSpilled - oldMetrics.map(_.diskBytesSpilled).getOrElse(0L)
stageData.diskBytesSpilled += diskSpillDelta
@@ -277,4 +472,5 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
private object JobProgressListener {
val DEFAULT_POOL_NAME = "default"
val DEFAULT_RETAINED_STAGES = 1000
+ val DEFAULT_RETAINED_JOBS = 1000
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala
deleted file mode 100644
index a82f71ed08475..0000000000000
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala
+++ /dev/null
@@ -1,98 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ui.jobs
-
-import javax.servlet.http.HttpServletRequest
-
-import scala.xml.{Node, NodeSeq}
-
-import org.apache.spark.scheduler.Schedulable
-import org.apache.spark.ui.{WebUIPage, UIUtils}
-
-/** Page showing list of all ongoing and recently finished stages and pools */
-private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") {
- private val live = parent.live
- private val sc = parent.sc
- private val listener = parent.listener
- private lazy val isFairScheduler = parent.isFairScheduler
-
- def render(request: HttpServletRequest): Seq[Node] = {
- listener.synchronized {
- val activeStages = listener.activeStages.values.toSeq
- val completedStages = listener.completedStages.reverse.toSeq
- val failedStages = listener.failedStages.reverse.toSeq
- val now = System.currentTimeMillis
-
- val activeStagesTable =
- new StageTableBase(activeStages.sortBy(_.submissionTime).reverse,
- parent, parent.killEnabled)
- val completedStagesTable =
- new StageTableBase(completedStages.sortBy(_.submissionTime).reverse, parent)
- val failedStagesTable =
- new FailedStageTable(failedStages.sortBy(_.submissionTime).reverse, parent)
-
- // For now, pool information is only accessible in live UIs
- val pools = if (live) sc.getAllPools else Seq[Schedulable]()
- val poolTable = new PoolTable(pools, parent)
-
- val summary: NodeSeq =
-
-
- {if (live) {
- // Total duration is not meaningful unless the UI is live
-
- Total Duration:
- {UIUtils.formatDuration(now - sc.startTime)}
-
++
- failedStagesTable.toNodeSeq
-
- UIUtils.headerSparkPage("Spark Stages", content, parent)
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala
deleted file mode 100644
index c16542c9db30f..0000000000000
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala
+++ /dev/null
@@ -1,55 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ui.jobs
-
-import javax.servlet.http.HttpServletRequest
-
-import org.apache.spark.SparkConf
-import org.apache.spark.scheduler.SchedulingMode
-import org.apache.spark.ui.{SparkUI, SparkUITab}
-
-/** Web UI showing progress status of all jobs in the given SparkContext. */
-private[ui] class JobProgressTab(parent: SparkUI) extends SparkUITab(parent, "stages") {
- val live = parent.live
- val sc = parent.sc
- val conf = if (live) sc.conf else new SparkConf
- val killEnabled = conf.getBoolean("spark.ui.killEnabled", true)
- val listener = new JobProgressListener(conf)
-
- attachPage(new JobProgressPage(this))
- attachPage(new StagePage(this))
- attachPage(new PoolPage(this))
- parent.registerListener(listener)
-
- def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR)
-
- def handleKillRequest(request: HttpServletRequest) = {
- if ((killEnabled) && (parent.securityManager.checkModifyPermissions(request.getRemoteUser))) {
- val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean
- val stageId = Option(request.getParameter("id")).getOrElse("-1").toInt
- if (stageId >= 0 && killFlag && listener.activeStages.contains(stageId)) {
- sc.cancelStage(stageId)
- }
- // Do a quick pause here to give Spark time to kill the stage so it shows up as
- // killed after the refresh. Note that this will block the serving thread so the
- // time should be limited in duration.
- Thread.sleep(100)
- }
- }
-
-}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
new file mode 100644
index 0000000000000..b2bbfdee56946
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+import org.apache.spark.scheduler.SchedulingMode
+import org.apache.spark.ui.{SparkUI, SparkUITab}
+
+/** Web UI showing progress status of all jobs in the given SparkContext. */
+private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") {
+ val sc = parent.sc
+ val killEnabled = parent.killEnabled
+ def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR)
+ val listener = parent.jobProgressListener
+
+ attachPage(new AllJobsPage(this))
+ attachPage(new JobPage(this))
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
index 7a6c7d1a497ed..5fc6cc7533150 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
@@ -25,8 +25,7 @@ import org.apache.spark.scheduler.{Schedulable, StageInfo}
import org.apache.spark.ui.{WebUIPage, UIUtils}
/** Page showing specific pool details */
-private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") {
- private val live = parent.live
+private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") {
private val sc = parent.sc
private val listener = parent.listener
@@ -38,11 +37,12 @@ private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") {
case Some(s) => s.values.toSeq
case None => Seq[StageInfo]()
}
- val activeStagesTable =
- new StageTableBase(activeStages.sortBy(_.submissionTime).reverse, parent)
+ val activeStagesTable = new StageTableBase(activeStages.sortBy(_.submissionTime).reverse,
+ parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler,
+ killEnabled = parent.killEnabled)
// For now, pool information is only accessible in live UIs
- val pools = if (live) Seq(sc.getPoolForName(poolName).get) else Seq[Schedulable]()
+ val pools = sc.map(_.getPoolForName(poolName).get).toSeq
val poolTable = new PoolTable(pools, parent)
val content =
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
index 64178e1e33d41..df1899e7a9b84 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
@@ -24,7 +24,7 @@ import org.apache.spark.scheduler.{Schedulable, StageInfo}
import org.apache.spark.ui.UIUtils
/** Table showing list of pools */
-private[ui] class PoolTable(pools: Seq[Schedulable], parent: JobProgressTab) {
+private[ui] class PoolTable(pools: Seq[Schedulable], parent: StagesTab) {
private val listener = parent.listener
def toNodeSeq: Seq[Node] = {
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index db01be596e073..bfa54f8492068 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -22,13 +22,16 @@ import javax.servlet.http.HttpServletRequest
import scala.xml.{Node, Unparsed}
+import org.apache.commons.lang3.StringEscapeUtils
+
+import org.apache.spark.executor.TaskMetrics
import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils}
import org.apache.spark.ui.jobs.UIData._
import org.apache.spark.util.{Utils, Distribution}
-import org.apache.spark.scheduler.AccumulableInfo
+import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo}
/** Page showing statistics and task list for a given stage */
-private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
+private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
private val listener = parent.listener
def render(request: HttpServletRequest): Seq[Node] = {
@@ -52,12 +55,13 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
val numCompleted = tasks.count(_.taskInfo.finished)
val accumulables = listener.stageIdToData((stageId, stageAttemptId)).accumulables
+ val hasAccumulators = accumulables.size > 0
val hasInput = stageData.inputBytes > 0
+ val hasOutput = stageData.outputBytes > 0
val hasShuffleRead = stageData.shuffleReadBytes > 0
val hasShuffleWrite = stageData.shuffleWriteBytes > 0
val hasBytesSpilled = stageData.memoryBytesSpilled > 0 && stageData.diskBytesSpilled > 0
- // scalastyle:off
val summary =
@@ -65,55 +69,125 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
Total task time across all tasks:
{UIUtils.formatDuration(stageData.executorRunTime)}
- {if (hasInput)
+ {if (hasInput) {
+:
+ getFormattedTimeQuantiles(gettingResultTimes)
// The scheduler delay includes the network delay to send the task to the worker
// machine and to send back the result (but not the time to fetch the task result,
// if it needed to be fetched from the block manager on the worker).
val schedulerDelays = validTasks.map { case TaskUIData(info, metrics, _) =>
- val totalExecutionTime = {
- if (info.gettingResultTime > 0) {
- (info.gettingResultTime - info.launchTime).toDouble
- } else {
- (info.finishTime - info.launchTime).toDouble
- }
- }
- totalExecutionTime - metrics.get.executorRunTime
+ getSchedulerDelay(info, metrics.get).toDouble
}
val schedulerDelayTitle =
- Some(UIUtils.listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true))
+ // The summary table does not use CSS to stripe rows, which doesn't work with hidden
+ // rows (instead, JavaScript in table.js is used to stripe the non-hidden rows).
+ Some(UIUtils.listingTable(
+ quantileHeaders,
+ identity[Seq[Node]],
+ listings,
+ fixedWidth = true,
+ id = Some("task-summary-table"),
+ stripeRowsWithCss = false))
}
val executorTable = new ExecutorTable(stageId, stageAttemptId, parent)
@@ -221,6 +340,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
val content =
summary ++
+ showAdditionalMetrics ++
Summary Metrics for {numCompleted} Completed Tasks
++
{summaryTable.getOrElse("No tasks have reported metrics yet.")}
++
Aggregated Metrics by Executor
++ executorTable.toNodeSeq ++
@@ -232,7 +352,9 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
}
def taskRow(
+ hasAccumulators: Boolean,
hasInput: Boolean,
+ hasOutput: Boolean,
hasShuffleRead: Boolean,
hasShuffleWrite: Boolean,
hasBytesSpilled: Boolean)(taskData: TaskUIData): Seq[Node] = {
@@ -241,8 +363,14 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
else metrics.map(_.executorRunTime).getOrElse(1L)
val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration)
else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("")
+ val schedulerDelay = metrics.map(getSchedulerDelay(info, _)).getOrElse(0L)
val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L)
+ val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L)
val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L)
+ val gettingResultTime = info.gettingResultTime
+
+ val maybeAccumulators = info.accumulables
+ val accumulatorsReadable = maybeAccumulators.map{acc => s"${acc.name}: ${acc.update.get}"}
val maybeInput = metrics.flatMap(_.inputMetrics)
val inputSortable = maybeInput.map(_.bytesRead.toString).getOrElse("")
@@ -250,6 +378,12 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
.map(m => s"${Utils.bytesToString(m.bytesRead)} (${m.readMethod.toString.toLowerCase()})")
.getOrElse("")
+ val maybeOutput = metrics.flatMap(_.outputMetrics)
+ val outputSortable = maybeOutput.map(_.bytesWritten.toString).getOrElse("")
+ val outputReadable = maybeOutput
+ .map(m => s"${Utils.bytesToString(m.bytesWritten)}")
+ .getOrElse("")
+
val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead)
val shuffleReadSortable = maybeShuffleRead.map(_.toString).getOrElse("")
val shuffleReadReadable = maybeShuffleRead.map(Utils.bytesToString).getOrElse("")
@@ -282,30 +416,45 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
}
++
@@ -43,6 +44,7 @@ private[ui] class StageTableBase(
Duration
Tasks: Succeeded/Total
Input
+
Output
Shuffle Read
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- build.dir
- ${user.dir}/build
-
-
-
- build.dir.hive
- ${build.dir}/hive
-
-
-
- hadoop.tmp.dir
- ${build.dir.hive}/test/hadoop-${user.name}
- A base for other temporary directories.
-
-
-
-
-
- hive.exec.scratchdir
- ${build.dir}/scratchdir
- Scratch space for Hive jobs
-
-
-
- hive.exec.local.scratchdir
- ${build.dir}/localscratchdir/
- Local scratch space for Hive jobs
-
-
-
- javax.jdo.option.ConnectionURL
-
- jdbc:derby:;databaseName=../build/test/junit_metastore_db;create=true
-
-
-
- javax.jdo.option.ConnectionDriverName
- org.apache.derby.jdbc.EmbeddedDriver
-
-
-
- javax.jdo.option.ConnectionUserName
- APP
-
-
-
- javax.jdo.option.ConnectionPassword
- mine
-
-
-
-
- hive.metastore.warehouse.dir
- ${test.warehouse.dir}
-
-
-
-
- hive.metastore.metadb.dir
- ${build.dir}/test/data/metadb/
-
- Required by metastore server or if the uris argument below is not supplied
-
-
-
-
- test.log.dir
- ${build.dir}/test/logs
-
-
-
-
- test.src.dir
- ${build.dir}/src/test
-
-
-
-
-
-
- hive.jar.path
- ${build.dir.hive}/ql/hive-exec-${version}.jar
-
-
-
-
- hive.metastore.rawstore.impl
- org.apache.hadoop.hive.metastore.ObjectStore
- Name of the class that implements org.apache.hadoop.hive.metastore.rawstore interface. This class is used to store and retrieval of raw metadata objects such as table, database
-
-
-
- hive.querylog.location
- ${build.dir}/tmp
- Location of the structured hive logs
-
-
-
-
-
- hive.task.progress
- false
- Track progress of a task
-
-
-
- hive.support.concurrency
- false
- Whether hive supports concurrency or not. A zookeeper instance must be up and running for the default hive lock manager to support read-write locks.
-
-
-
- fs.pfile.impl
- org.apache.hadoop.fs.ProxyLocalFileSystem
- A proxy for local file system used for cross file system testing
-
-
-
- hive.exec.mode.local.auto
- false
-
- Let hive determine whether to run in local mode automatically
- Disabling this for tests so that minimr is not affected
-
-
-
-
- hive.auto.convert.join
- false
- Whether Hive enable the optimization about converting common join into mapjoin based on the input file size
-
-
-
- hive.ignore.mapjoin.hint
- false
- Whether Hive ignores the mapjoin hint
-
-
-
- hive.input.format
- org.apache.hadoop.hive.ql.io.CombineHiveInputFormat
- The default input format, if it is not specified, the system assigns it. It is set to HiveInputFormat for hadoop versions 17, 18 and 19, whereas it is set to CombineHiveInputFormat for hadoop 20. The user can always overwrite it - if there is a bug in CombineHiveInputFormat, it can always be manually set to HiveInputFormat.
-
-
-
- hive.default.rcfile.serde
- org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe
- The default SerDe hive will use for the rcfile format
-
-
-
diff --git a/dev/change-version-to-2.10.sh b/dev/change-version-to-2.10.sh
new file mode 100755
index 0000000000000..7473c20d28e09
--- /dev/null
+++ b/dev/change-version-to-2.10.sh
@@ -0,0 +1,20 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+find . -name 'pom.xml' | grep -v target \
+ | xargs -I {} sed -i -e 's|\(artifactId.*\)_2.11|\1_2.10|g' {}
diff --git a/dev/change-version-to-2.11.sh b/dev/change-version-to-2.11.sh
new file mode 100755
index 0000000000000..3957a9f3ba258
--- /dev/null
+++ b/dev/change-version-to-2.11.sh
@@ -0,0 +1,21 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+find . -name 'pom.xml' | grep -v target \
+ | xargs -I {} sed -i -e 's|\(artifactId.*\)_2.10|\1_2.11|g' {}
diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh
index 281e8d4de6d71..e0aca467ac949 100755
--- a/dev/create-release/create-release.sh
+++ b/dev/create-release/create-release.sh
@@ -27,13 +27,20 @@
# Would be nice to add:
# - Send output to stderr and have useful logging in stdout
-GIT_USERNAME=${GIT_USERNAME:-pwendell}
-GIT_PASSWORD=${GIT_PASSWORD:-XXX}
+# Note: The following variables must be set before use!
+ASF_USERNAME=${ASF_USERNAME:-pwendell}
+ASF_PASSWORD=${ASF_PASSWORD:-XXX}
GPG_PASSPHRASE=${GPG_PASSPHRASE:-XXX}
GIT_BRANCH=${GIT_BRANCH:-branch-1.0}
-RELEASE_VERSION=${RELEASE_VERSION:-1.0.0}
+RELEASE_VERSION=${RELEASE_VERSION:-1.2.0}
+NEXT_VERSION=${NEXT_VERSION:-1.2.1}
RC_NAME=${RC_NAME:-rc2}
-USER_NAME=${USER_NAME:-pwendell}
+
+M2_REPO=~/.m2/repository
+SPARK_REPO=$M2_REPO/org/apache/spark
+NEXUS_ROOT=https://repository.apache.org/service/local/staging
+NEXUS_UPLOAD=$NEXUS_ROOT/deploy/maven2
+NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads
if [ -z "$JAVA_HOME" ]; then
echo "Error: JAVA_HOME is not set, cannot proceed."
@@ -46,31 +53,90 @@ set -e
GIT_TAG=v$RELEASE_VERSION-$RC_NAME
if [[ ! "$@" =~ --package-only ]]; then
- echo "Creating and publishing release"
+ echo "Creating release commit and publishing to Apache repository"
# Artifact publishing
- git clone https://git-wip-us.apache.org/repos/asf/spark.git -b $GIT_BRANCH
- cd spark
+ git clone https://$ASF_USERNAME:$ASF_PASSWORD@git-wip-us.apache.org/repos/asf/spark.git \
+ -b $GIT_BRANCH
+ pushd spark
export MAVEN_OPTS="-Xmx3g -XX:MaxPermSize=1g -XX:ReservedCodeCacheSize=1g"
- mvn -Pyarn release:clean
-
- mvn -DskipTests \
- -Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \
- -Dusername=$GIT_USERNAME -Dpassword=$GIT_PASSWORD \
- -Dmaven.javadoc.skip=true \
- -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \
- -Dtag=$GIT_TAG -DautoVersionSubmodules=true \
- -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \
- --batch-mode release:prepare
-
- mvn -DskipTests \
- -Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \
- -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \
- -Dmaven.javadoc.skip=true \
+ # Create release commits and push them to github
+ # NOTE: This is done "eagerly" i.e. we don't check if we can succesfully build
+ # or before we coin the release commit. This helps avoid races where
+ # other people add commits to this branch while we are in the middle of building.
+ old=" ${RELEASE_VERSION}-SNAPSHOT<\/version>"
+ new=" ${RELEASE_VERSION}<\/version>"
+ find . -name pom.xml -o -name package.scala | grep -v dev | xargs -I {} sed -i \
+ -e "s/$old/$new/" {}
+ git commit -a -m "Preparing Spark release $GIT_TAG"
+ echo "Creating tag $GIT_TAG at the head of $GIT_BRANCH"
+ git tag $GIT_TAG
+
+ old=" ${RELEASE_VERSION}<\/version>"
+ new=" ${NEXT_VERSION}-SNAPSHOT<\/version>"
+ find . -name pom.xml -o -name package.scala | grep -v dev | xargs -I {} sed -i \
+ -e "s/$old/$new/" {}
+ git commit -a -m "Preparing development version ${NEXT_VERSION}-SNAPSHOT"
+ git push origin $GIT_TAG
+ git push origin HEAD:$GIT_BRANCH
+ git checkout -f $GIT_TAG
+
+ # Using Nexus API documented here:
+ # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API
+ echo "Creating Nexus staging repository"
+ repo_request="Apache Spark $GIT_TAG"
+ out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \
+ -H "Content-Type:application/xml" -v \
+ $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start)
+ staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/")
+ echo "Created Nexus staging repository: $staged_repo_id"
+
+ rm -rf $SPARK_REPO
+
+ mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \
-Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \
- release:perform
+ clean install
- cd ..
+ ./dev/change-version-to-2.11.sh
+
+ mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \
+ -Dscala-2.11 -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \
+ clean install
+
+ ./dev/change-version-to-2.10.sh
+
+ pushd $SPARK_REPO
+
+ # Remove any extra files generated during install
+ find . -type f |grep -v \.jar |grep -v \.pom | xargs rm
+
+ echo "Creating hash and signature files"
+ for file in $(find . -type f)
+ do
+ echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --output $file.asc --detach-sig --armour $file;
+ gpg --print-md MD5 $file > $file.md5;
+ gpg --print-md SHA1 $file > $file.sha1
+ done
+
+ echo "Uplading files to $NEXUS_UPLOAD"
+ for file in $(find . -type f)
+ do
+ # strip leading ./
+ file_short=$(echo $file | sed -e "s/\.\///")
+ dest_url="$NEXUS_UPLOAD/org/apache/spark/$file_short"
+ echo " Uploading $file_short"
+ curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url
+ done
+
+ echo "Closing nexus staging repository"
+ repo_request="$staged_repo_idApache Spark $GIT_TAG"
+ out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \
+ -H "Content-Type:application/xml" -v \
+ $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish)
+ echo "Closed Nexus staging repository: $staged_repo_id"
+
+ popd
+ popd
rm -rf spark
fi
@@ -101,7 +167,13 @@ make_binary_release() {
cp -r spark spark-$RELEASE_VERSION-bin-$NAME
cd spark-$RELEASE_VERSION-bin-$NAME
- ./make-distribution.sh --name $NAME --tgz $FLAGS
+
+ # TODO There should probably be a flag to make-distribution to allow 2.11 support
+ if [[ $FLAGS == *scala-2.11* ]]; then
+ ./dev/change-version-to-2.11.sh
+ fi
+
+ ./make-distribution.sh --name $NAME --tgz $FLAGS 2>&1 | tee ../binary-release-$NAME.log
cd ..
cp spark-$RELEASE_VERSION-bin-$NAME/spark-$RELEASE_VERSION-bin-$NAME.tgz .
rm -rf spark-$RELEASE_VERSION-bin-$NAME
@@ -117,22 +189,24 @@ make_binary_release() {
spark-$RELEASE_VERSION-bin-$NAME.tgz.sha
}
-make_binary_release "hadoop1" "-Phive -Dhadoop.version=1.0.4" &
-make_binary_release "cdh4" "-Phive -Dhadoop.version=2.0.0-mr1-cdh4.2.0" &
-make_binary_release "hadoop2.3" "-Phadoop-2.3 -Phive -Pyarn" &
-make_binary_release "hadoop2.4" "-Phadoop-2.4 -Phive -Pyarn" &
+
+make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" &
+make_binary_release "hadoop1-scala2.11" "-Phive -Dscala-2.11" &
+make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" &
+make_binary_release "hadoop2.3" "-Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" &
+make_binary_release "hadoop2.4" "-Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" &
+make_binary_release "mapr3" "-Pmapr3 -Phive -Phive-thriftserver" &
+make_binary_release "mapr4" "-Pmapr4 -Pyarn -Phive -Phive-thriftserver" &
make_binary_release "hadoop2.4-without-hive" "-Phadoop-2.4 -Pyarn" &
-make_binary_release "mapr3" "-Pmapr3 -Phive" &
-make_binary_release "mapr4" "-Pmapr4 -Pyarn -Phive" &
wait
# Copy data
echo "Copying release tarballs"
rc_folder=spark-$RELEASE_VERSION-$RC_NAME
-ssh $USER_NAME@people.apache.org \
- mkdir /home/$USER_NAME/public_html/$rc_folder
+ssh $ASF_USERNAME@people.apache.org \
+ mkdir /home/$ASF_USERNAME/public_html/$rc_folder
scp spark-* \
- $USER_NAME@people.apache.org:/home/$USER_NAME/public_html/$rc_folder/
+ $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_folder/
# Docs
cd spark
@@ -142,12 +216,12 @@ cd docs
JAVA_HOME=$JAVA_7_HOME PRODUCTION=1 jekyll build
echo "Copying release documentation"
rc_docs_folder=${rc_folder}-docs
-ssh $USER_NAME@people.apache.org \
- mkdir /home/$USER_NAME/public_html/$rc_docs_folder
-rsync -r _site/* $USER_NAME@people.apache.org:/home/$USER_NAME/public_html/$rc_docs_folder
+ssh $ASF_USERNAME@people.apache.org \
+ mkdir /home/$ASF_USERNAME/public_html/$rc_docs_folder
+rsync -r _site/* $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_docs_folder
echo "Release $RELEASE_VERSION completed:"
echo "Git tag:\t $GIT_TAG"
echo "Release commit:\t $release_hash"
-echo "Binary location:\t http://people.apache.org/~$USER_NAME/$rc_folder"
-echo "Doc location:\t http://people.apache.org/~$USER_NAME/$rc_docs_folder"
+echo "Binary location:\t http://people.apache.org/~$ASF_USERNAME/$rc_folder"
+echo "Doc location:\t http://people.apache.org/~$ASF_USERNAME/$rc_docs_folder"
diff --git a/dev/create-release/generate-contributors.py b/dev/create-release/generate-contributors.py
new file mode 100755
index 0000000000000..f4bf734081583
--- /dev/null
+++ b/dev/create-release/generate-contributors.py
@@ -0,0 +1,206 @@
+#!/usr/bin/env python
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This script automates the process of creating release notes.
+
+import os
+import re
+import sys
+
+from releaseutils import *
+
+# You must set the following before use!
+JIRA_API_BASE = os.environ.get("JIRA_API_BASE", "https://issues.apache.org/jira")
+START_COMMIT = os.environ.get("START_COMMIT", "37b100")
+END_COMMIT = os.environ.get("END_COMMIT", "3693ae")
+
+try:
+ from jira.client import JIRA
+except ImportError:
+ print "This tool requires the jira-python library"
+ print "Install using 'sudo pip install jira-python'"
+ sys.exit(-1)
+
+try:
+ import unidecode
+except ImportError:
+ print "This tool requires the unidecode library to decode obscure github usernames"
+ print "Install using 'sudo pip install unidecode'"
+ sys.exit(-1)
+
+# If commit range is not specified, prompt the user to provide it
+if not START_COMMIT or not END_COMMIT:
+ print "A commit range is required to proceed."
+ if not START_COMMIT:
+ START_COMMIT = raw_input("Please specify starting commit hash (inclusive): ")
+ if not END_COMMIT:
+ END_COMMIT = raw_input("Please specify ending commit hash (non-inclusive): ")
+
+# Verify provided arguments
+start_commit_line = get_one_line(START_COMMIT)
+end_commit_line = get_one_line(END_COMMIT)
+num_commits = num_commits_in_range(START_COMMIT, END_COMMIT)
+if not start_commit_line: sys.exit("Start commit %s not found!" % START_COMMIT)
+if not end_commit_line: sys.exit("End commit %s not found!" % END_COMMIT)
+if num_commits == 0:
+ sys.exit("There are no commits in the provided range [%s, %s)" % (START_COMMIT, END_COMMIT))
+print "\n=================================================================================="
+print "JIRA server: %s" % JIRA_API_BASE
+print "Start commit (inclusive): %s" % start_commit_line
+print "End commit (non-inclusive): %s" % end_commit_line
+print "Number of commits in this range: %s" % num_commits
+print
+response = raw_input("Is this correct? [Y/n] ")
+if response.lower() != "y" and response:
+ sys.exit("Ok, exiting")
+print "==================================================================================\n"
+
+# Find all commits within this range
+print "Gathering commits within range [%s..%s)" % (START_COMMIT, END_COMMIT)
+commits = get_one_line_commits(START_COMMIT, END_COMMIT)
+if not commits: sys.exit("Error: No commits found within this range!")
+commits = commits.split("\n")
+
+# Filter out special commits
+releases = []
+reverts = []
+nojiras = []
+filtered_commits = []
+def is_release(commit):
+ return re.findall("\[release\]", commit.lower()) or\
+ "maven-release-plugin" in commit or "CHANGES.txt" in commit
+def has_no_jira(commit):
+ return not re.findall("SPARK-[0-9]+", commit.upper())
+def is_revert(commit):
+ return "revert" in commit.lower()
+def is_docs(commit):
+ return re.findall("docs*", commit.lower()) or "programming guide" in commit.lower()
+for c in commits:
+ if not c: continue
+ elif is_release(c): releases.append(c)
+ elif is_revert(c): reverts.append(c)
+ elif is_docs(c): filtered_commits.append(c) # docs may not have JIRA numbers
+ elif has_no_jira(c): nojiras.append(c)
+ else: filtered_commits.append(c)
+
+# Warn against ignored commits
+def print_indented(_list):
+ for x in _list: print " %s" % x
+if releases or reverts or nojiras:
+ print "\n=================================================================================="
+ if releases: print "Releases (%d)" % len(releases); print_indented(releases)
+ if reverts: print "Reverts (%d)" % len(reverts); print_indented(reverts)
+ if nojiras: print "No JIRA (%d)" % len(nojiras); print_indented(nojiras)
+ print "==================== Warning: the above commits will be ignored ==================\n"
+response = raw_input("%d commits left to process. Ok to proceed? [y/N] " % len(filtered_commits))
+if response.lower() != "y":
+ sys.exit("Ok, exiting.")
+
+# Keep track of warnings to tell the user at the end
+warnings = []
+
+# Populate a map that groups issues and components by author
+# It takes the form: Author name -> { Contribution type -> Spark components }
+# For instance,
+# {
+# 'Andrew Or': {
+# 'bug fixes': ['windows', 'core', 'web ui'],
+# 'improvements': ['core']
+# },
+# 'Tathagata Das' : {
+# 'bug fixes': ['streaming']
+# 'new feature': ['streaming']
+# }
+# }
+#
+author_info = {}
+jira_options = { "server": JIRA_API_BASE }
+jira = JIRA(jira_options)
+print "\n=========================== Compiling contributor list ==========================="
+for commit in filtered_commits:
+ commit_hash = re.findall("^[a-z0-9]+", commit)[0]
+ issues = re.findall("SPARK-[0-9]+", commit.upper())
+ author = get_author(commit_hash)
+ author = unidecode.unidecode(unicode(author, "UTF-8")) # guard against special characters
+ date = get_date(commit_hash)
+ # Parse components from the commit message, if any
+ commit_components = find_components(commit, commit_hash)
+ # Populate or merge an issue into author_info[author]
+ def populate(issue_type, components):
+ components = components or [CORE_COMPONENT] # assume core if no components provided
+ if author not in author_info:
+ author_info[author] = {}
+ if issue_type not in author_info[author]:
+ author_info[author][issue_type] = set()
+ for component in all_components:
+ author_info[author][issue_type].add(component)
+ # Find issues and components associated with this commit
+ for issue in issues:
+ jira_issue = jira.issue(issue)
+ jira_type = jira_issue.fields.issuetype.name
+ jira_type = translate_issue_type(jira_type, issue, warnings)
+ jira_components = [translate_component(c.name, commit_hash, warnings)\
+ for c in jira_issue.fields.components]
+ all_components = set(jira_components + commit_components)
+ populate(jira_type, all_components)
+ # For docs without an associated JIRA, manually add it ourselves
+ if is_docs(commit) and not issues:
+ populate("documentation", commit_components)
+ print " Processed commit %s authored by %s on %s" % (commit_hash, author, date)
+print "==================================================================================\n"
+
+# Write to contributors file ordered by author names
+# Each line takes the format "Author name - semi-colon delimited contributions"
+# e.g. Andrew Or - Bug fixes in Windows, Core, and Web UI; improvements in Core
+# e.g. Tathagata Das - Bug fixes and new features in Streaming
+contributors_file_name = "contributors.txt"
+contributors_file = open(contributors_file_name, "w")
+authors = author_info.keys()
+authors.sort()
+for author in authors:
+ contribution = ""
+ components = set()
+ issue_types = set()
+ for issue_type, comps in author_info[author].items():
+ components.update(comps)
+ issue_types.add(issue_type)
+ # If there is only one component, mention it only once
+ # e.g. Bug fixes, improvements in MLlib
+ if len(components) == 1:
+ contribution = "%s in %s" % (nice_join(issue_types), next(iter(components)))
+ # Otherwise, group contributions by issue types instead of modules
+ # e.g. Bug fixes in MLlib, Core, and Streaming; documentation in YARN
+ else:
+ contributions = ["%s in %s" % (issue_type, nice_join(comps)) \
+ for issue_type, comps in author_info[author].items()]
+ contribution = "; ".join(contributions)
+ # Do not use python's capitalize() on the whole string to preserve case
+ assert contribution
+ contribution = contribution[0].capitalize() + contribution[1:]
+ line = "%s - %s" % (author, contribution)
+ contributors_file.write(line + "\n")
+contributors_file.close()
+print "Contributors list is successfully written to %s!" % contributors_file_name
+
+# Log any warnings encountered in the process
+if warnings:
+ print "\n============ Warnings encountered while creating the contributor list ============"
+ for w in warnings: print w
+ print "Please correct these in the final contributors list at %s." % contributors_file_name
+ print "==================================================================================\n"
+
diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py
new file mode 100755
index 0000000000000..e56d7fa58fa2c
--- /dev/null
+++ b/dev/create-release/releaseutils.py
@@ -0,0 +1,124 @@
+#!/usr/bin/env python
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This file contains helper methods used in creating a release.
+
+import re
+from subprocess import Popen, PIPE
+
+# Utility functions run git commands (written with Git 1.8.5)
+def run_cmd(cmd): return Popen(cmd, stdout=PIPE).communicate()[0]
+def get_author(commit_hash):
+ return run_cmd(["git", "show", "--quiet", "--pretty=format:%an", commit_hash])
+def get_date(commit_hash):
+ return run_cmd(["git", "show", "--quiet", "--pretty=format:%cd", commit_hash])
+def get_one_line(commit_hash):
+ return run_cmd(["git", "show", "--quiet", "--pretty=format:\"%h %cd %s\"", commit_hash])
+def get_one_line_commits(start_hash, end_hash):
+ return run_cmd(["git", "log", "--oneline", "%s..%s" % (start_hash, end_hash)])
+def num_commits_in_range(start_hash, end_hash):
+ output = run_cmd(["git", "log", "--oneline", "%s..%s" % (start_hash, end_hash)])
+ lines = [line for line in output.split("\n") if line] # filter out empty lines
+ return len(lines)
+
+# Maintain a mapping for translating issue types to contributions in the release notes
+# This serves an additional function of warning the user against unknown issue types
+# Note: This list is partially derived from this link:
+# https://issues.apache.org/jira/plugins/servlet/project-config/SPARK/issuetypes
+# Keep these in lower case
+known_issue_types = {
+ "bug": "bug fixes",
+ "build": "build fixes",
+ "improvement": "improvements",
+ "new feature": "new features",
+ "documentation": "documentation"
+}
+
+# Maintain a mapping for translating component names when creating the release notes
+# This serves an additional function of warning the user against unknown components
+# Note: This list is largely derived from this link:
+# https://issues.apache.org/jira/plugins/servlet/project-config/SPARK/components
+CORE_COMPONENT = "Core"
+known_components = {
+ "block manager": CORE_COMPONENT,
+ "build": CORE_COMPONENT,
+ "deploy": CORE_COMPONENT,
+ "documentation": CORE_COMPONENT,
+ "ec2": "EC2",
+ "examples": CORE_COMPONENT,
+ "graphx": "GraphX",
+ "input/output": CORE_COMPONENT,
+ "java api": "Java API",
+ "mesos": "Mesos",
+ "ml": "MLlib",
+ "mllib": "MLlib",
+ "project infra": "Project Infra",
+ "pyspark": "PySpark",
+ "shuffle": "Shuffle",
+ "spark core": CORE_COMPONENT,
+ "spark shell": CORE_COMPONENT,
+ "sql": "SQL",
+ "streaming": "Streaming",
+ "web ui": "Web UI",
+ "windows": "Windows",
+ "yarn": "YARN"
+}
+
+# Translate issue types using a format appropriate for writing contributions
+# If an unknown issue type is encountered, warn the user
+def translate_issue_type(issue_type, issue_id, warnings):
+ issue_type = issue_type.lower()
+ if issue_type in known_issue_types:
+ return known_issue_types[issue_type]
+ else:
+ warnings.append("Unknown issue type \"%s\" (see %s)" % (issue_type, issue_id))
+ return issue_type
+
+# Translate component names using a format appropriate for writing contributions
+# If an unknown component is encountered, warn the user
+def translate_component(component, commit_hash, warnings):
+ component = component.lower()
+ if component in known_components:
+ return known_components[component]
+ else:
+ warnings.append("Unknown component \"%s\" (see %s)" % (component, commit_hash))
+ return component
+
+# Parse components in the commit message
+# The returned components are already filtered and translated
+def find_components(commit, commit_hash):
+ components = re.findall("\[\w*\]", commit.lower())
+ components = [translate_component(c, commit_hash)\
+ for c in components if c in known_components]
+ return components
+
+# Join a list of strings in a human-readable manner
+# e.g. ["Juice"] -> "Juice"
+# e.g. ["Juice", "baby"] -> "Juice and baby"
+# e.g. ["Juice", "baby", "moon"] -> "Juice, baby, and moon"
+def nice_join(str_list):
+ str_list = list(str_list) # sometimes it's a set
+ if not str_list:
+ return ""
+ elif len(str_list) == 1:
+ return next(iter(str_list))
+ elif len(str_list) == 2:
+ return " and ".join(str_list)
+ else:
+ return ", ".join(str_list[:-1]) + ", and " + str_list[-1]
+
diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py
index a8e92e36fe0d8..02ac20984add9 100755
--- a/dev/merge_spark_pr.py
+++ b/dev/merge_spark_pr.py
@@ -73,11 +73,10 @@ def fail(msg):
def run_cmd(cmd):
+ print cmd
if isinstance(cmd, list):
- print " ".join(cmd)
return subprocess.check_output(cmd)
else:
- print cmd
return subprocess.check_output(cmd.split(" "))
diff --git a/dev/run-tests b/dev/run-tests
index c3d8f49cdd993..328a73bd8b26d 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -24,6 +24,16 @@ cd "$FWDIR"
# Remove work directory
rm -rf ./work
+source "$FWDIR/dev/run-tests-codes.sh"
+
+CURRENT_BLOCK=$BLOCK_GENERAL
+
+function handle_error () {
+ echo "[error] Got a return code of $? on line $1 of the run-tests script."
+ exit $CURRENT_BLOCK
+}
+
+
# Build against the right verison of Hadoop.
{
if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then
@@ -32,7 +42,7 @@ rm -rf ./work
elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.0" ]; then
export SBT_MAVEN_PROFILES_ARGS="-Dhadoop.version=2.0.0-mr1-cdh4.1.1"
elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.2" ]; then
- export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Dhadoop.version=2.2.0"
+ export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0"
elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.3" ]; then
export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0"
fi
@@ -91,26 +101,34 @@ if [ -n "$AMPLAB_JENKINS" ]; then
fi
fi
-# Fail fast
-set -e
set -o pipefail
+trap 'handle_error $LINENO' ERR
echo ""
echo "========================================================================="
echo "Running Apache RAT checks"
echo "========================================================================="
+
+CURRENT_BLOCK=$BLOCK_RAT
+
./dev/check-license
echo ""
echo "========================================================================="
echo "Running Scala style checks"
echo "========================================================================="
+
+CURRENT_BLOCK=$BLOCK_SCALA_STYLE
+
./dev/lint-scala
echo ""
echo "========================================================================="
echo "Running Python style checks"
echo "========================================================================="
+
+CURRENT_BLOCK=$BLOCK_PYTHON_STYLE
+
./dev/lint-python
echo ""
@@ -118,21 +136,29 @@ echo "========================================================================="
echo "Building Spark"
echo "========================================================================="
-{
- # We always build with Hive because the PySpark Spark SQL tests need it.
- BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive"
+CURRENT_BLOCK=$BLOCK_BUILD
- echo "[info] Building Spark with these arguments: $BUILD_MVN_PROFILE_ARGS"
+{
# NOTE: echo "q" is needed because sbt on encountering a build file with failure
#+ (either resolution or compilation) prompts the user for input either q, r, etc
#+ to quit or retry. This echo is there to make it not block.
- # NOTE: Do not quote $BUILD_MVN_PROFILE_ARGS or else it will be interpreted as a
+ # NOTE: Do not quote $BUILD_MVN_PROFILE_ARGS or else it will be interpreted as a
#+ single argument!
# QUESTION: Why doesn't 'yes "q"' work?
# QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work?
+ # First build with 0.12 to ensure patches do not break the hive 12 build
+ HIVE_12_BUILD_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver -Phive-0.12.0"
+ echo "[info] Compile with hive 0.12"
+ echo -e "q\n" \
+ | sbt/sbt $HIVE_12_BUILD_ARGS clean hive/compile hive-thriftserver/compile \
+ | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including"
+
+ # Then build with default version(0.13.1) because tests are based on this version
+ echo "[info] Building Spark with these arguments: $SBT_MAVEN_PROFILES_ARGS"\
+ " -Phive -Phive-thriftserver"
echo -e "q\n" \
- | sbt/sbt $BUILD_MVN_PROFILE_ARGS clean package assembly/assembly \
+ | sbt/sbt $SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver package assembly/assembly \
| grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including"
}
@@ -141,17 +167,19 @@ echo "========================================================================="
echo "Running Spark unit tests"
echo "========================================================================="
+CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS
+
{
# If the Spark SQL tests are enabled, run the tests with the Hive profiles enabled.
# This must be a single argument, as it is.
if [ -n "$_RUN_SQL_TESTS" ]; then
- SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive"
+ SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver"
fi
if [ -n "$_SQL_TESTS_ONLY" ]; then
# This must be an array of individual arguments. Otherwise, having one long string
#+ will be interpreted as a single test, which doesn't work.
- SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "hive-thriftserver/test")
+ SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "mllib/test")
else
SBT_MAVEN_TEST_ARGS=("test")
fi
@@ -175,10 +203,16 @@ echo ""
echo "========================================================================="
echo "Running PySpark tests"
echo "========================================================================="
+
+CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS
+
./python/run-tests
echo ""
echo "========================================================================="
echo "Detecting binary incompatibilites with MiMa"
echo "========================================================================="
+
+CURRENT_BLOCK=$BLOCK_MIMA
+
./dev/mima
diff --git a/dev/run-tests-codes.sh b/dev/run-tests-codes.sh
new file mode 100644
index 0000000000000..1348e0609dda4
--- /dev/null
+++ b/dev/run-tests-codes.sh
@@ -0,0 +1,27 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+readonly BLOCK_GENERAL=10
+readonly BLOCK_RAT=11
+readonly BLOCK_SCALA_STYLE=12
+readonly BLOCK_PYTHON_STYLE=13
+readonly BLOCK_BUILD=14
+readonly BLOCK_SPARK_UNIT_TESTS=15
+readonly BLOCK_PYSPARK_UNIT_TESTS=16
+readonly BLOCK_MIMA=17
diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins
index 0b1e31b9413cf..6a849e4f77207 100755
--- a/dev/run-tests-jenkins
+++ b/dev/run-tests-jenkins
@@ -26,9 +26,23 @@
FWDIR="$(cd `dirname $0`/..; pwd)"
cd "$FWDIR"
+source "$FWDIR/dev/run-tests-codes.sh"
+
COMMENTS_URL="https://api.github.com/repos/apache/spark/issues/$ghprbPullId/comments"
PULL_REQUEST_URL="https://github.com/apache/spark/pull/$ghprbPullId"
+# Important Environment Variables
+# ---
+# $ghprbActualCommit
+#+ This is the hash of the most recent commit in the PR.
+#+ The merge-base of this and master is the commit from which the PR was branched.
+# $sha1
+#+ If the patch merges cleanly, this is a reference to the merge commit hash
+#+ (e.g. "origin/pr/2606/merge").
+#+ If the patch does not merge cleanly, it is equal to $ghprbActualCommit.
+#+ The merge-base of this and master in the case of a clean merge is the most recent commit
+#+ against master.
+
COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}"
# GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :(
SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}"
@@ -39,9 +53,9 @@ function post_message () {
local message=$1
local data="{\"body\": \"$message\"}"
local HTTP_CODE_HEADER="HTTP Response Code: "
-
+
echo "Attempting to post to Github..."
-
+
local curl_output=$(
curl `#--dump-header -` \
--silent \
@@ -61,12 +75,12 @@ function post_message () {
echo " > data: ${data}" >&2
# exit $curl_status
fi
-
+
local api_response=$(
echo "${curl_output}" \
| grep -v -e "^${HTTP_CODE_HEADER}"
)
-
+
local http_code=$(
echo "${curl_output}" \
| grep -e "^${HTTP_CODE_HEADER}" \
@@ -78,60 +92,97 @@ function post_message () {
echo " > api_response: ${api_response}" >&2
echo " > data: ${data}" >&2
fi
-
+
if [ "$curl_status" -eq 0 ] && [ "$http_code" -eq "201" ]; then
echo " > Post successful."
fi
}
+function send_archived_logs () {
+ echo "Archiving unit tests logs..."
+
+ local log_files=$(
+ find .\
+ -name "unit-tests.log" -o\
+ -path "./sql/hive/target/HiveCompatibilitySuite.failed" -o\
+ -path "./sql/hive/target/HiveCompatibilitySuite.hiveFailed" -o\
+ -path "./sql/hive/target/HiveCompatibilitySuite.wrong"
+ )
+
+ if [ -z "$log_files" ]; then
+ echo "> No log files found." >&2
+ else
+ local log_archive="unit-tests-logs.tar.gz"
+ echo "$log_files" | xargs tar czf ${log_archive}
+
+ local jenkins_build_dir=${JENKINS_HOME}/jobs/${JOB_NAME}/builds/${BUILD_NUMBER}
+ local scp_output=$(scp ${log_archive} amp-jenkins-master:${jenkins_build_dir}/${log_archive})
+ local scp_status="$?"
+
+ if [ "$scp_status" -ne 0 ]; then
+ echo "Failed to send archived unit tests logs to Jenkins master." >&2
+ echo "> scp_status: ${scp_status}" >&2
+ echo "> scp_output: ${scp_output}" >&2
+ else
+ echo "> Send successful."
+ fi
+
+ rm -f ${log_archive}
+ fi
+}
+
+
+# We diff master...$ghprbActualCommit because that gets us changes introduced in the PR
+#+ and not anything else added to master since the PR was branched.
+
# check PR merge-ability and check for new public classes
{
if [ "$sha1" == "$ghprbActualCommit" ]; then
- merge_note=" * This patch **does not** merge cleanly!"
+ merge_note=" * This patch **does not merge cleanly**."
else
merge_note=" * This patch merges cleanly."
+ fi
- source_files=$(
- git diff master... --name-only `# diff patch against master from branch point` \
- | grep -v -e "\/test" `# ignore files in test directories` \
- | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \
- | tr "\n" " "
- )
- new_public_classes=$(
- git diff master... ${source_files} `# diff patch against master from branch point` \
- | grep "^\+" `# filter in only added lines` \
- | sed -r -e "s/^\+//g" `# remove the leading +` \
- | grep -e "trait " -e "class " `# filter in lines with these key words` \
- | grep -e "{" -e "(" `# filter in lines with these key words, too` \
- | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \
- | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \
- | sed -r -e "s/\{.*//g" `# remove from the { onwards` \
- | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \
- | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \
- | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \
- | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \
- | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \
- | tr -d "\n" `# remove actual LF characters`
- )
-
- if [ "$new_public_classes" == "" ]; then
- public_classes_note=" * This patch adds no public classes."
- else
- public_classes_note=" * This patch adds the following public classes _(experimental)_:"
- public_classes_note="${public_classes_note}\n${new_public_classes}"
- fi
+ source_files=$(
+ git diff master...$ghprbActualCommit --name-only `# diff patch against master from branch point` \
+ | grep -v -e "\/test" `# ignore files in test directories` \
+ | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \
+ | tr "\n" " "
+ )
+ new_public_classes=$(
+ git diff master...$ghprbActualCommit ${source_files} `# diff patch against master from branch point` \
+ | grep "^\+" `# filter in only added lines` \
+ | sed -r -e "s/^\+//g" `# remove the leading +` \
+ | grep -e "trait " -e "class " `# filter in lines with these key words` \
+ | grep -e "{" -e "(" `# filter in lines with these key words, too` \
+ | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \
+ | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \
+ | sed -r -e "s/\{.*//g" `# remove from the { onwards` \
+ | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \
+ | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \
+ | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \
+ | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \
+ | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \
+ | tr -d "\n" `# remove actual LF characters`
+ )
+
+ if [ -z "$new_public_classes" ]; then
+ public_classes_note=" * This patch adds no public classes."
+ else
+ public_classes_note=" * This patch adds the following public classes _(experimental)_:"
+ public_classes_note="${public_classes_note}\n${new_public_classes}"
fi
}
# post start message
{
start_message="\
- [QA tests have started](${BUILD_URL}consoleFull) for \
+ [Test build ${BUILD_DISPLAY_NAME} has started](${BUILD_URL}consoleFull) for \
PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})."
-
+
start_message="${start_message}\n${merge_note}"
# start_message="${start_message}\n${public_classes_note}"
-
+
post_message "$start_message"
}
@@ -141,25 +192,45 @@ function post_message () {
test_result="$?"
if [ "$test_result" -eq "124" ]; then
- fail_message="**[Tests timed out](${BUILD_URL}consoleFull)** \
+ fail_message="**[Test build ${BUILD_DISPLAY_NAME} timed out](${BUILD_URL}consoleFull)** \
for PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL}) \
after a configured wait of \`${TESTS_TIMEOUT}\`."
post_message "$fail_message"
exit $test_result
+ elif [ "$test_result" -eq "0" ]; then
+ test_result_note=" * This patch **passes all tests**."
else
- if [ "$test_result" -eq "0" ]; then
- test_result_note=" * This patch **passes** unit tests."
+ if [ "$test_result" -eq "$BLOCK_GENERAL" ]; then
+ failing_test="some tests"
+ elif [ "$test_result" -eq "$BLOCK_RAT" ]; then
+ failing_test="RAT tests"
+ elif [ "$test_result" -eq "$BLOCK_SCALA_STYLE" ]; then
+ failing_test="Scala style tests"
+ elif [ "$test_result" -eq "$BLOCK_PYTHON_STYLE" ]; then
+ failing_test="Python style tests"
+ elif [ "$test_result" -eq "$BLOCK_BUILD" ]; then
+ failing_test="to build"
+ elif [ "$test_result" -eq "$BLOCK_SPARK_UNIT_TESTS" ]; then
+ failing_test="Spark unit tests"
+ elif [ "$test_result" -eq "$BLOCK_PYSPARK_UNIT_TESTS" ]; then
+ failing_test="PySpark unit tests"
+ elif [ "$test_result" -eq "$BLOCK_MIMA" ]; then
+ failing_test="MiMa tests"
else
- test_result_note=" * This patch **fails** unit tests."
+ failing_test="some tests"
fi
+
+ test_result_note=" * This patch **fails $failing_test**."
fi
+
+ send_archived_logs
}
# post end message
{
result_message="\
- [QA tests have finished](${BUILD_URL}consoleFull) for \
+ [Test build ${BUILD_DISPLAY_NAME} has finished](${BUILD_URL}consoleFull) for \
PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})."
result_message="${result_message}\n${test_result_note}"
diff --git a/dev/scalastyle b/dev/scalastyle
index efb5f291ea3b7..c3c6012e74ffa 100755
--- a/dev/scalastyle
+++ b/dev/scalastyle
@@ -17,7 +17,7 @@
# limitations under the License.
#
-echo -e "q\n" | sbt/sbt -Phive scalastyle > scalastyle.txt
+echo -e "q\n" | sbt/sbt -Phive -Phive-thriftserver scalastyle > scalastyle.txt
# Check style with YARN alpha built too
echo -e "q\n" | sbt/sbt -Pyarn-alpha -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \
>> scalastyle.txt
@@ -25,7 +25,9 @@ echo -e "q\n" | sbt/sbt -Pyarn-alpha -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-
echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 yarn/scalastyle \
>> scalastyle.txt
-ERRORS=$(cat scalastyle.txt | grep -e "\")
+ERRORS=$(cat scalastyle.txt | awk '{if($1~/error/)print}')
+rm scalastyle.txt
+
if test ! -z "$ERRORS"; then
echo -e "Scalastyle checks failed at following occurrences:\n$ERRORS"
exit 1
diff --git a/docs/README.md b/docs/README.md
index 79708c3df9106..119484038083f 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -25,8 +25,7 @@ installing via the Ruby Gem dependency manager. Since the exact HTML output
varies between versions of Jekyll and its dependencies, we list specific versions here
in some cases:
- $ sudo gem install jekyll -v 1.4.3
- $ sudo gem uninstall kramdown -v 1.4.1
+ $ sudo gem install jekyll
$ sudo gem install jekyll-redirect-from
Execute `jekyll` from the `docs/` directory. Compiling the site with Jekyll will create a directory
@@ -44,7 +43,7 @@ You can modify the default Jekyll build as follows:
## Pygments
We also use pygments (http://pygments.org) for syntax highlighting in documentation markdown pages,
-so you will also need to install that (it requires Python) by running `sudo easy_install Pygments`.
+so you will also need to install that (it requires Python) by running `sudo pip install Pygments`.
To mark a block of code in your markdown to be syntax highlighted by jekyll during the compile
phase, use the following sytax:
@@ -54,19 +53,24 @@ phase, use the following sytax:
// supported languages too.
{% endhighlight %}
-## API Docs (Scaladoc and Epydoc)
+## Sphinx
+
+We use Sphinx to generate Python API docs, so you will need to install it by running
+`sudo pip install sphinx`.
+
+## API Docs (Scaladoc and Sphinx)
You can build just the Spark scaladoc by running `sbt/sbt doc` from the SPARK_PROJECT_ROOT directory.
-Similarly, you can build just the PySpark epydoc by running `epydoc --config epydoc.conf` from the
-SPARK_PROJECT_ROOT/pyspark directory. Documentation is only generated for classes that are listed as
+Similarly, you can build just the PySpark docs by running `make html` from the
+SPARK_PROJECT_ROOT/python/docs directory. Documentation is only generated for classes that are listed as
public in `__init__.py`.
When you run `jekyll` in the `docs` directory, it will also copy over the scaladoc for the various
Spark subprojects into the `docs` directory (and then also into the `_site` directory). We use a
jekyll plugin to run `sbt/sbt doc` before building the site so if you haven't run it (recently) it
may take some time as it generates all of the scaladoc. The jekyll plugin also generates the
-PySpark docs using [epydoc](http://epydoc.sourceforge.net/).
+PySpark docs [Sphinx](http://sphinx-doc.org/).
NOTE: To skip the step of building and copying over the Scala and Python API docs, run `SKIP_API=1
jekyll`.
diff --git a/docs/_config.yml b/docs/_config.yml
index 7bc3a78e2d265..a96a76dd9ab5e 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -8,10 +8,13 @@ gems:
kramdown:
entity_output: numeric
-# These allow the documentation to be updated with nerw releases
+include:
+ - _static
+
+# These allow the documentation to be updated with newer releases
# of Spark, Scala, and Mesos.
-SPARK_VERSION: 1.0.0-SNAPSHOT
-SPARK_VERSION_SHORT: 1.0.0
+SPARK_VERSION: 1.3.0-SNAPSHOT
+SPARK_VERSION_SHORT: 1.3.0
SCALA_BINARY_VERSION: "2.10"
SCALA_VERSION: "2.10.4"
MESOS_VERSION: 0.18.1
diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb
index 3b02e090aec28..4566a2fff562b 100644
--- a/docs/_plugins/copy_api_dirs.rb
+++ b/docs/_plugins/copy_api_dirs.rb
@@ -63,19 +63,20 @@
puts "cp -r " + source + "/. " + dest
cp_r(source + "/.", dest)
- # Build Epydoc for Python
- puts "Moving to python directory and building epydoc."
- cd("../python")
- puts `epydoc --config epydoc.conf`
+ # Build Sphinx docs for Python
- puts "Moving back into docs dir."
- cd("../docs")
+ puts "Moving to python/docs directory and building sphinx."
+ cd("../python/docs")
+ puts `make html`
+
+ puts "Moving back into home dir."
+ cd("../../")
puts "Making directory api/python"
- mkdir_p "api/python"
+ mkdir_p "docs/api/python"
- puts "cp -r ../python/docs/. api/python"
- cp_r("../python/docs/.", "api/python")
+ puts "cp -r python/docs/_build/html/. docs/api/python"
+ cp_r("python/docs/_build/html/.", "docs/api/python")
cd("..")
end
diff --git a/docs/building-spark.md b/docs/building-spark.md
index 901c157162fee..40a47410e683a 100644
--- a/docs/building-spark.md
+++ b/docs/building-spark.md
@@ -67,11 +67,13 @@ For Apache Hadoop 2.x, 0.23.x, Cloudera CDH, and other Hadoop versions with YARN
YARN version
Profile required
-
0.23.x to 2.1.x
yarn-alpha
+
0.23.x to 2.1.x
yarn-alpha (Deprecated.)
2.2.x and later
yarn
+Note: Support for YARN-alpha API's will be removed in Spark 1.3 (see SPARK-3445).
+
Examples:
{% highlight bash %}
@@ -90,8 +92,11 @@ mvn -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -DskipTests clean package
# Apache Hadoop 2.3.X
mvn -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -DskipTests clean package
-# Apache Hadoop 2.4.X
-mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -DskipTests clean package
+# Apache Hadoop 2.4.X or 2.5.X
+mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=VERSION -DskipTests clean package
+
+Versions of Hadoop after 2.5.X may or may not work with the -Phadoop-2.4 profile (they were
+released after this version of Spark).
# Different versions of HDFS and YARN.
mvn -Pyarn-alpha -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=0.23.7 -DskipTests clean package
@@ -99,20 +104,34 @@ mvn -Pyarn-alpha -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=0.23.7 -Dski
# Building With Hive and JDBC Support
To enable Hive integration for Spark SQL along with its JDBC server and CLI,
-add the `-Phive` profile to your existing build options.
+add the `-Phive` and `Phive-thriftserver` profiles to your existing build options.
+By default Spark will build with Hive 0.13.1 bindings. You can also build for
+Hive 0.12.0 using the `-Phive-0.12.0` profile.
{% highlight bash %}
-# Apache Hadoop 2.4.X with Hive support
-mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -DskipTests clean package
+# Apache Hadoop 2.4.X with Hive 13 support
+mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -DskipTests clean package
+
+# Apache Hadoop 2.4.X with Hive 12 support
+mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-0.12.0 -Phive-thriftserver -DskipTests clean package
{% endhighlight %}
+# Building for Scala 2.11
+To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` property:
+
+ mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package
+
+Scala 2.11 support in Spark is experimental and does not support a few features.
+Specifically, Spark's external Kafka library and JDBC component are not yet
+supported in Scala 2.11 builds.
+
# Spark Tests in Maven
Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.org/user_guide/using_the_scalatest_maven_plugin).
Some of the tests require Spark to be packaged first, so always run `mvn package` with `-DskipTests` the first time. The following is an example of a correct (build, test) sequence:
- mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive clean package
- mvn -Pyarn -Phadoop-2.3 -Phive test
+ mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive -Phive-thriftserver clean package
+ mvn -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver test
The ScalaTest plugin also supports running only a specific test suite as follows:
@@ -171,10 +190,25 @@ can be set to control the SBT build. For example:
sbt/sbt -Pyarn -Phadoop-2.3 assembly
+# Testing with SBT
+
+Some of the tests require Spark to be packaged first, so always run `sbt/sbt assembly` the first time. The following is an example of a correct (build, test) sequence:
+
+ sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver assembly
+ sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver test
+
+To run only a specific test suite as follows:
+
+ sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver "test-only org.apache.spark.repl.ReplSuite"
+
+To run test suites of a specific sub project as follows:
+
+ sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver core/test
+
# Speeding up Compilation with Zinc
[Zinc](https://github.com/typesafehub/zinc) is a long-running server version of SBT's incremental
compiler. When run locally as a background process, it speeds up builds of Scala-based projects
like Spark. Developers who regularly recompile Spark with Maven will be the most interested in
Zinc. The project site gives instructions for building and running `zinc`; OS X users can
-install it using `brew install zinc`.
\ No newline at end of file
+install it using `brew install zinc`.
diff --git a/docs/configuration.md b/docs/configuration.md
index 1c33855365170..0b77f5ab645c9 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -21,16 +21,22 @@ application. These properties can be set directly on a
[SparkConf](api/scala/index.html#org.apache.spark.SparkConf) passed to your
`SparkContext`. `SparkConf` allows you to configure some of the common properties
(e.g. master URL and application name), as well as arbitrary key-value pairs through the
-`set()` method. For example, we could initialize an application as follows:
+`set()` method. For example, we could initialize an application with two threads as follows:
+
+Note that we run with local[2], meaning two threads - which represents "minimal" parallelism,
+which can help detect bugs that only exist when we run in a distributed context.
{% highlight scala %}
val conf = new SparkConf()
- .setMaster("local")
+ .setMaster("local[2]")
.setAppName("CountingSheep")
.set("spark.executor.memory", "1g")
val sc = new SparkContext(conf)
{% endhighlight %}
+Note that we can have more than 1 thread in local mode, and in cases like spark streaming, we may actually
+require one to prevent any sort of starvation issues.
+
## Dynamically Loading Spark Properties
In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For
instance, if you'd like to run the same application with different masters or different
@@ -46,7 +52,7 @@ Then, you can supply configuration values at runtime:
--conf "spark.executor.extraJavaOptions=-XX:+PrintGCDetails -XX:+PrintGCTimeStamps" myApp.jar
{% endhighlight %}
-The Spark shell and [`spark-submit`](cluster-overview.html#launching-applications-with-spark-submit)
+The Spark shell and [`spark-submit`](submitting-applications.html)
tool support two ways to load configurations dynamically. The first are command line options,
such as `--master`, as shown above. `spark-submit` can accept any Spark property using the `--conf`
flag, but uses special flags for properties that play a part in launching the Spark application.
@@ -103,6 +109,26 @@ of the most common options to set are:
(e.g. 512m, 2g).
+
+
spark.driver.memory
+
512m
+
+ Amount of memory to use for the driver process, i.e. where SparkContext is initialized.
+ (e.g. 512m, 2g).
+
+
+
+
spark.driver.maxResultSize
+
1g
+
+ Limit of total size of serialized results of all partitions for each Spark action (e.g. collect).
+ Should be at least 1M, or 0 for unlimited. Jobs will be aborted if the total size
+ is above this limit.
+ Having a high limit may cause out-of-memory errors in driver (depends on spark.driver.memory
+ and memory overhead of objects in JVM). Setting a proper limit can protect the driver from
+ out-of-memory errors.
+
+
spark.serializer
org.apache.spark.serializer. JavaSerializer
@@ -116,12 +142,23 @@ of the most common options to set are:
org.apache.spark.Serializer.
+
+
spark.kryo.classesToRegister
+
(none)
+
+ If you use Kryo serialization, give a comma-separated list of custom class names to register
+ with Kryo.
+ See the tuning guide for more details.
+
+
spark.kryo.registrator
(none)
- If you use Kryo serialization, set this class to register your custom classes with Kryo.
- It should be set to a class that extends
+ If you use Kryo serialization, set this class to register your custom classes with Kryo. This
+ property is useful if you need to register your classes in a custom way, e.g. to specify a custom
+ field serializer. Otherwise spark.kryo.classesToRegister is simpler. It should be
+ set to a class that extends
KryoRegistrator.
See the tuning guide for more details.
@@ -153,14 +190,6 @@ Apart from these, the following properties are also available, and may be useful
#### Runtime Environment
Property Name
Default
Meaning
-
-
spark.executor.memory
-
512m
-
- Amount of memory to use per executor process, in the same format as JVM memory strings
- (e.g. 512m, 2g).
-
-
spark.executor.extraJavaOptions
(none)
@@ -195,6 +224,7 @@ Apart from these, the following properties are also available, and may be useful
(Experimental) Whether to give user-added jars precedence over Spark's own jars when
loading classes in Executors. This feature can be used to mitigate conflicts between
Spark's dependencies and user dependencies. It is currently an experimental feature.
+ (Currently, this setting does not work for YARN, see SPARK-2996 for more details).
@@ -348,6 +378,16 @@ Apart from these, the following properties are also available, and may be useful
map-side aggregation and there are at most this many reduce partitions.
+
+
spark.shuffle.blockTransferService
+
netty
+
+ Implementation to use for transferring shuffle and cached blocks between executors. There
+ are two implementations available: netty and nio. Netty-based
+ block transfer is intended to be simpler but equally efficient and is the default option
+ starting in 1.2.
+
+
#### Spark UI
@@ -357,14 +397,23 @@ Apart from these, the following properties are also available, and may be useful
spark.ui.port
4040
- Port for your application's dashboard, which shows memory and workload data
+ Port for your application's dashboard, which shows memory and workload data.
spark.ui.retainedStages
1000
- How many stages the Spark UI remembers before garbage collecting.
+ How many stages the Spark UI and status APIs remember before garbage
+ collecting.
+
+
+
+
spark.ui.retainedJobs
+
1000
+
+ How many stages the Spark UI and status APIs remember before garbage
+ collecting.
@@ -514,6 +563,9 @@ Apart from these, the following properties are also available, and may be useful
spark.default.parallelism
+ For distributed shuffle operations like reduceByKey and join, the
+ largest number of partitions in a parent RDD. For operations like parallelize
+ with no parent RDDs, it depends on the cluster manager:
Local mode: number of cores on the local machine
Mesos fine grained mode: 8
@@ -521,8 +573,8 @@ Apart from these, the following properties are also available, and may be useful
- Default number of tasks to use across the cluster for distributed shuffle operations
- (groupByKey, reduceByKey, etc) when not set by user.
+ Default number of partitions in RDDs returned by transformations like join,
+ reduceByKey, and parallelize when not set by user.
@@ -619,6 +671,15 @@ Apart from these, the following properties are also available, and may be useful
output directories. We recommend that users do not disable this except if trying to achieve compatibility with
previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand.
+
+
spark.hadoop.cloneConf
+
false
+
If set to true, clones a new Hadoop Configuration object for each task. This
+ option should be enabled to work around Configuration thread-safety issues (see
+ SPARK-2546 for more details).
+ This is disabled by default in order to avoid unexpected performance regressions for jobs that
+ are not affected by these issues.
+
spark.executor.heartbeatInterval
10000
@@ -717,7 +778,7 @@ Apart from these, the following properties are also available, and may be useful
spark.akka.heartbeat.pauses
-
600
+
6000
This is set to a larger value to disable failure detector that comes inbuilt akka. It can be
enabled again, if you plan to use this feature (Not recommended). Acceptable heart beat pause
@@ -872,8 +933,8 @@ Apart from these, the following properties are also available, and may be useful
spark.scheduler.revive.interval
1000
- The interval length for the scheduler to revive the worker resource offers to run tasks.
- (in milliseconds)
+ The interval length for the scheduler to revive the worker resource offers to run tasks
+ (in milliseconds).
@@ -885,7 +946,7 @@ Apart from these, the following properties are also available, and may be useful
to wait for before scheduling begins. Specified as a double between 0 and 1.
Regardless of whether the minimum ratio of resources has been reached,
the maximum amount of time it will wait before scheduling begins is controlled by config
- spark.scheduler.maxRegisteredResourcesWaitingTime
+ spark.scheduler.maxRegisteredResourcesWaitingTime.
diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md
index 530798f2b8022..66bf5f1a855ed 100644
--- a/docs/ec2-scripts.md
+++ b/docs/ec2-scripts.md
@@ -12,16 +12,14 @@ on the [Amazon Web Services site](http://aws.amazon.com/).
`spark-ec2` is designed to manage multiple named clusters. You can
launch a new cluster (telling the script its size and giving it a name),
-shutdown an existing cluster, or log into a cluster. Each cluster
-launches a set of instances, which are tagged with the cluster name,
-and placed into EC2 security groups. If you don't specify a security
-group, the `spark-ec2` script will create security groups based on the
-cluster name you request. For example, a cluster named
+shutdown an existing cluster, or log into a cluster. Each cluster is
+identified by placing its machines into EC2 security groups whose names
+are derived from the name of the cluster. For example, a cluster named
`test` will contain a master node in a security group called
`test-master`, and a number of slave nodes in a security group called
-`test-slaves`. You can also specify a security group prefix to be used
-in place of the cluster name. Machines in a cluster can be identified
-by looking for the "Name" tag of the instance in the Amazon EC2 Console.
+`test-slaves`. The `spark-ec2` script will create these security groups
+for you based on the cluster name you request. You can also use them to
+identify machines belonging to each cluster in the Amazon EC2 Console.
# Before You Start
diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md
index fdb9f98e214e5..e298c51f8a5b7 100644
--- a/docs/graphx-programming-guide.md
+++ b/docs/graphx-programming-guide.md
@@ -6,6 +6,47 @@ title: GraphX Programming Guide
* This will become a table of contents (this text will be scraped).
{:toc}
+
+
+[EdgeRDD]: api/scala/index.html#org.apache.spark.graphx.EdgeRDD
+[Edge]: api/scala/index.html#org.apache.spark.graphx.Edge
+[EdgeTriplet]: api/scala/index.html#org.apache.spark.graphx.EdgeTriplet
+[Graph]: api/scala/index.html#org.apache.spark.graphx.Graph
+[GraphOps]: api/scala/index.html#org.apache.spark.graphx.GraphOps
+[Graph.mapVertices]: api/scala/index.html#org.apache.spark.graphx.Graph@mapVertices[VD2]((VertexId,VD)⇒VD2)(ClassTag[VD2]):Graph[VD2,ED]
+[Graph.reverse]: api/scala/index.html#org.apache.spark.graphx.Graph@reverse:Graph[VD,ED]
+[Graph.subgraph]: api/scala/index.html#org.apache.spark.graphx.Graph@subgraph((EdgeTriplet[VD,ED])⇒Boolean,(VertexId,VD)⇒Boolean):Graph[VD,ED]
+[Graph.mask]: api/scala/index.html#org.apache.spark.graphx.Graph@mask[VD2,ED2](Graph[VD2,ED2])(ClassTag[VD2],ClassTag[ED2]):Graph[VD,ED]
+[Graph.groupEdges]: api/scala/index.html#org.apache.spark.graphx.Graph@groupEdges((ED,ED)⇒ED):Graph[VD,ED]
+[GraphOps.joinVertices]: api/scala/index.html#org.apache.spark.graphx.GraphOps@joinVertices[U](RDD[(VertexId,U)])((VertexId,VD,U)⇒VD)(ClassTag[U]):Graph[VD,ED]
+[Graph.outerJoinVertices]: api/scala/index.html#org.apache.spark.graphx.Graph@outerJoinVertices[U,VD2](RDD[(VertexId,U)])((VertexId,VD,Option[U])⇒VD2)(ClassTag[U],ClassTag[VD2]):Graph[VD2,ED]
+[Graph.aggregateMessages]: api/scala/index.html#org.apache.spark.graphx.Graph@aggregateMessages[A]((EdgeContext[VD,ED,A])⇒Unit,(A,A)⇒A,TripletFields)(ClassTag[A]):VertexRDD[A]
+[EdgeContext]: api/scala/index.html#org.apache.spark.graphx.EdgeContext
+[Graph.mapReduceTriplets]: api/scala/index.html#org.apache.spark.graphx.Graph@mapReduceTriplets[A](mapFunc:org.apache.spark.graphx.EdgeTriplet[VD,ED]=>Iterator[(org.apache.spark.graphx.VertexId,A)],reduceFunc:(A,A)=>A,activeSetOpt:Option[(org.apache.spark.graphx.VertexRDD[_],org.apache.spark.graphx.EdgeDirection)])(implicitevidence$10:scala.reflect.ClassTag[A]):org.apache.spark.graphx.VertexRDD[A]
+[GraphOps.collectNeighborIds]: api/scala/index.html#org.apache.spark.graphx.GraphOps@collectNeighborIds(EdgeDirection):VertexRDD[Array[VertexId]]
+[GraphOps.collectNeighbors]: api/scala/index.html#org.apache.spark.graphx.GraphOps@collectNeighbors(EdgeDirection):VertexRDD[Array[(VertexId,VD)]]
+[RDD Persistence]: programming-guide.html#rdd-persistence
+[Graph.cache]: api/scala/index.html#org.apache.spark.graphx.Graph@cache():Graph[VD,ED]
+[GraphOps.pregel]: api/scala/index.html#org.apache.spark.graphx.GraphOps@pregel[A](A,Int,EdgeDirection)((VertexId,VD,A)⇒VD,(EdgeTriplet[VD,ED])⇒Iterator[(VertexId,A)],(A,A)⇒A)(ClassTag[A]):Graph[VD,ED]
+[PartitionStrategy]: api/scala/index.html#org.apache.spark.graphx.PartitionStrategy$
+[GraphLoader.edgeListFile]: api/scala/index.html#org.apache.spark.graphx.GraphLoader$@edgeListFile(SparkContext,String,Boolean,Int):Graph[Int,Int]
+[Graph.apply]: api/scala/index.html#org.apache.spark.graphx.Graph$@apply[VD,ED](RDD[(VertexId,VD)],RDD[Edge[ED]],VD)(ClassTag[VD],ClassTag[ED]):Graph[VD,ED]
+[Graph.fromEdgeTuples]: api/scala/index.html#org.apache.spark.graphx.Graph$@fromEdgeTuples[VD](RDD[(VertexId,VertexId)],VD,Option[PartitionStrategy])(ClassTag[VD]):Graph[VD,Int]
+[Graph.fromEdges]: api/scala/index.html#org.apache.spark.graphx.Graph$@fromEdges[VD,ED](RDD[Edge[ED]],VD)(ClassTag[VD],ClassTag[ED]):Graph[VD,ED]
+[PartitionStrategy]: api/scala/index.html#org.apache.spark.graphx.PartitionStrategy
+[Graph.partitionBy]: api/scala/index.html#org.apache.spark.graphx.Graph$@partitionBy(partitionStrategy:org.apache.spark.graphx.PartitionStrategy):org.apache.spark.graphx.Graph[VD,ED]
+[PageRank]: api/scala/index.html#org.apache.spark.graphx.lib.PageRank$
+[ConnectedComponents]: api/scala/index.html#org.apache.spark.graphx.lib.ConnectedComponents$
+[TriangleCount]: api/scala/index.html#org.apache.spark.graphx.lib.TriangleCount$
+[Graph.partitionBy]: api/scala/index.html#org.apache.spark.graphx.Graph@partitionBy(PartitionStrategy):Graph[VD,ED]
+[EdgeContext.sendToSrc]: api/scala/index.html#org.apache.spark.graphx.EdgeContext@sendToSrc(msg:A):Unit
+[EdgeContext.sendToDst]: api/scala/index.html#org.apache.spark.graphx.EdgeContext@sendToDst(msg:A):Unit
+[TripletFields]: api/java/org/apache/spark/graphx/TripletFields.html
+[TripletFields.All]: api/java/org/apache/spark/graphx/TripletFields.html#All
+[TripletFields.None]: api/java/org/apache/spark/graphx/TripletFields.html#None
+[TripletFields.Src]: api/java/org/apache/spark/graphx/TripletFields.html#Src
+[TripletFields.Dst]: api/java/org/apache/spark/graphx/TripletFields.html#Dst
+
-
-
-
+1. To improve performance we have introduced a new version of
+[`mapReduceTriplets`][Graph.mapReduceTriplets] called
+[`aggregateMessages`][Graph.aggregateMessages] which takes the messages previously returned from
+[`mapReduceTriplets`][Graph.mapReduceTriplets] through a callback ([`EdgeContext`][EdgeContext])
+rather than by return value.
+We are deprecating [`mapReduceTriplets`][Graph.mapReduceTriplets] and encourage users to consult
+the [transition guide](#mrTripletsTransition).
-However, the same restrictions that enable these substantial performance gains also make it
-difficult to express many of the important stages in a typical graph-analytics pipeline:
-constructing the graph, modifying its structure, or expressing computation that spans multiple
-graphs. Furthermore, how we look at data depends on our objectives and the same raw data may have
-many different table and graph views.
-
-
-
-
-
-
-As a consequence, it is often necessary to be able to move between table and graph views of the same
-physical data and to leverage the properties of each view to easily and efficiently express
-computation. However, existing graph analytics pipelines must compose graph-parallel and data-
-parallel systems, leading to extensive data movement and duplication and a complicated programming
-model.
-
-
-
-
-
-
-The goal of the GraphX project is to unify graph-parallel and data-parallel computation in one
-system with a single composable API. The GraphX API enables users to view data both as a graph and
-as collections (i.e., RDDs) without data movement or duplication. By incorporating recent advances
-in graph-parallel systems, GraphX is able to optimize the execution of graph operations.
-
-## GraphX Replaces the Spark Bagel API
-
-Prior to the release of GraphX, graph computation in Spark was expressed using Bagel, an
-implementation of Pregel. GraphX improves upon Bagel by exposing a richer property graph API, a
-more streamlined version of the Pregel abstraction, and system optimizations to improve performance
-and reduce memory overhead. While we plan to eventually deprecate Bagel, we will continue to
-support the [Bagel API](api/scala/index.html#org.apache.spark.bagel.package) and
-[Bagel programming guide](bagel-programming-guide.html). However, we encourage Bagel users to
-explore the new GraphX API and comment on issues that may complicate the transition from Bagel.
-
-## Migrating from Spark 0.9.1
-
-GraphX in Spark {{site.SPARK_VERSION}} contains one user-facing interface change from Spark 0.9.1. [`EdgeRDD`][EdgeRDD] may now store adjacent vertex attributes to construct the triplets, so it has gained a type parameter. The edges of a graph of type `Graph[VD, ED]` are of type `EdgeRDD[ED, VD]` rather than `EdgeRDD[ED]`.
-
-[EdgeRDD]: api/scala/index.html#org.apache.spark.graphx.EdgeRDD
+2. In Spark 1.0 and 1.1, the type signature of [`EdgeRDD`][EdgeRDD] switched from
+`EdgeRDD[ED]` to `EdgeRDD[ED, VD]` to enable some caching optimizations. We have since discovered
+a more elegant solution and have restored the type signature to the more natural `EdgeRDD[ED]` type.
# Getting Started
@@ -108,9 +96,10 @@ import org.apache.spark.rdd.RDD
If you are not using the Spark shell you will also need a `SparkContext`. To learn more about
getting started with Spark refer to the [Spark Quick Start Guide](quick-start.html).
-# The Property Graph
+# The Property Graph
+
The [property graph](api/scala/index.html#org.apache.spark.graphx.Graph) is a directed multigraph
with user defined objects attached to each vertex and edge. A directed multigraph is a directed
graph with potentially multiple parallel edges sharing the same source and destination vertex. The
@@ -123,7 +112,7 @@ identifiers.
The property graph is parameterized over the vertex (`VD`) and edge (`ED`) types. These
are the types of the objects associated with each vertex and edge respectively.
-> GraphX optimizes the representation of vertex and edge types when they are plain old data-types
+> GraphX optimizes the representation of vertex and edge types when they are primitive data types
> (e.g., int, double, etc...) reducing the in memory footprint by storing them in specialized
> arrays.
@@ -142,8 +131,8 @@ var graph: Graph[VertexProperty, String] = null
Like RDDs, property graphs are immutable, distributed, and fault-tolerant. Changes to the values or
structure of the graph are accomplished by producing a new graph with the desired changes. Note
that substantial parts of the original graph (i.e., unaffected structure, attributes, and indicies)
-are reused in the new graph reducing the cost of this inherently functional data-structure. The
-graph is partitioned across the executors using a range of vertex-partitioning heuristics. As with
+are reused in the new graph reducing the cost of this inherently functional data structure. The
+graph is partitioned across the executors using a range of vertex partitioning heuristics. As with
RDDs, each partition of the graph can be recreated on a different machine in the event of a failure.
Logically the property graph corresponds to a pair of typed collections (RDDs) encoding the
@@ -153,12 +142,12 @@ the vertices and edges of the graph:
{% highlight scala %}
class Graph[VD, ED] {
val vertices: VertexRDD[VD]
- val edges: EdgeRDD[ED, VD]
+ val edges: EdgeRDD[ED]
}
{% endhighlight %}
-The classes `VertexRDD[VD]` and `EdgeRDD[ED, VD]` extend and are optimized versions of `RDD[(VertexID,
-VD)]` and `RDD[Edge[ED]]` respectively. Both `VertexRDD[VD]` and `EdgeRDD[ED, VD]` provide additional
+The classes `VertexRDD[VD]` and `EdgeRDD[ED]` extend and are optimized versions of `RDD[(VertexID,
+VD)]` and `RDD[Edge[ED]]` respectively. Both `VertexRDD[VD]` and `EdgeRDD[ED]` provide additional
functionality built around graph computation and leverage internal optimizations. We discuss the
`VertexRDD` and `EdgeRDD` API in greater detail in the section on [vertex and edge
RDDs](#vertex_and_edge_rdds) but for now they can be thought of as simply RDDs of the form:
@@ -211,7 +200,6 @@ In the above example we make use of the [`Edge`][Edge] case class. Edges have a
`dstId` corresponding to the source and destination vertex identifiers. In addition, the `Edge`
class has an `attr` member which stores the edge property.
-[Edge]: api/scala/index.html#org.apache.spark.graphx.Edge
We can deconstruct a graph into the respective vertex and edge views by using the `graph.vertices`
and `graph.edges` members respectively.
@@ -237,7 +225,6 @@ The triplet view logically joins the vertex and edge properties yielding an
`RDD[EdgeTriplet[VD, ED]]` containing instances of the [`EdgeTriplet`][EdgeTriplet] class. This
*join* can be expressed in the following SQL expression:
-[EdgeTriplet]: api/scala/index.html#org.apache.spark.graphx.EdgeTriplet
{% highlight sql %}
SELECT src.id, dst.id, src.attr, e.attr, dst.attr
@@ -278,9 +265,6 @@ core operators are defined in [`GraphOps`][GraphOps]. However, thanks to Scala
operators in `GraphOps` are automatically available as members of `Graph`. For example, we can
compute the in-degree of each vertex (defined in `GraphOps`) by the following:
-[Graph]: api/scala/index.html#org.apache.spark.graphx.Graph
-[GraphOps]: api/scala/index.html#org.apache.spark.graphx.GraphOps
-
{% highlight scala %}
val graph: Graph[(String, String), String]
// Use the implicit GraphOps.inDegrees operator
@@ -310,7 +294,7 @@ class Graph[VD, ED] {
val degrees: VertexRDD[Int]
// Views of the graph as collections =============================================================
val vertices: VertexRDD[VD]
- val edges: EdgeRDD[ED, VD]
+ val edges: EdgeRDD[ED]
val triplets: RDD[EdgeTriplet[VD, ED]]
// Functions for caching graphs ==================================================================
def persist(newLevel: StorageLevel = StorageLevel.MEMORY_ONLY): Graph[VD, ED]
@@ -341,10 +325,10 @@ class Graph[VD, ED] {
// Aggregate information about adjacent triplets =================================================
def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexID]]
def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexID, VD)]]
- def mapReduceTriplets[A: ClassTag](
- mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexID, A)],
- reduceFunc: (A, A) => A,
- activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None)
+ def aggregateMessages[Msg: ClassTag](
+ sendMsg: EdgeContext[VD, ED, Msg] => Unit,
+ mergeMsg: (Msg, Msg) => Msg,
+ tripletFields: TripletFields = TripletFields.All)
: VertexRDD[A]
// Iterative graph-parallel computation ==========================================================
def pregel[A](initialMsg: A, maxIterations: Int, activeDirection: EdgeDirection)(
@@ -363,8 +347,7 @@ class Graph[VD, ED] {
## Property Operators
-In direct analogy to the RDD `map` operator, the property
-graph contains the following:
+Like the RDD `map` operator, the property graph contains the following:
{% highlight scala %}
class Graph[VD, ED] {
@@ -377,7 +360,7 @@ class Graph[VD, ED] {
Each of these operators yields a new graph with the vertex or edge properties modified by the user
defined `map` function.
-> Note that in all cases the graph structure is unaffected. This is a key feature of these operators
+> Note that in each case the graph structure is unaffected. This is a key feature of these operators
> which allows the resulting graph to reuse the structural indices of the original graph. The
> following snippets are logically equivalent, but the first one does not preserve the structural
> indices and would not benefit from the GraphX system optimizations:
@@ -390,14 +373,13 @@ val newGraph = Graph(newVertices, graph.edges)
val newGraph = graph.mapVertices((id, attr) => mapUdf(id, attr))
{% endhighlight %}
-[Graph.mapVertices]: api/scala/index.html#org.apache.spark.graphx.Graph@mapVertices[VD2]((VertexId,VD)⇒VD2)(ClassTag[VD2]):Graph[VD2,ED]
These operators are often used to initialize the graph for a particular computation or project away
-unnecessary properties. For example, given a graph with the out-degrees as the vertex properties
+unnecessary properties. For example, given a graph with the out degrees as the vertex properties
(we describe how to construct such a graph later), we initialize it for PageRank:
{% highlight scala %}
-// Given a graph where the vertex property is the out-degree
+// Given a graph where the vertex property is the out degree
val inputGraph: Graph[Int, String] =
graph.outerJoinVertices(graph.outDegrees)((vid, _, degOpt) => degOpt.getOrElse(0))
// Construct a graph where each edge contains the weight
@@ -406,9 +388,10 @@ val outputGraph: Graph[Double, Double] =
inputGraph.mapTriplets(triplet => 1.0 / triplet.srcAttr).mapVertices((id, _) => 1.0)
{% endhighlight %}
-## Structural Operators
+## Structural Operators
+
Currently GraphX supports only a simple set of commonly used structural operators and we expect to
add more in the future. The following is a list of the basic structural operators.
@@ -425,9 +408,8 @@ class Graph[VD, ED] {
The [`reverse`][Graph.reverse] operator returns a new graph with all the edge directions reversed.
This can be useful when, for example, trying to compute the inverse PageRank. Because the reverse
operation does not modify vertex or edge properties or change the number of edges, it can be
-implemented efficiently without data-movement or duplication.
+implemented efficiently without data movement or duplication.
-[Graph.reverse]: api/scala/index.html#org.apache.spark.graphx.Graph@reverse:Graph[VD,ED]
The [`subgraph`][Graph.subgraph] operator takes vertex and edge predicates and returns the graph
containing only the vertices that satisfy the vertex predicate (evaluate to true) and edges that
@@ -435,7 +417,6 @@ satisfy the edge predicate *and connect vertices that satisfy the vertex predica
operator can be used in number of situations to restrict the graph to the vertices and edges of
interest or eliminate broken links. For example in the following code we remove broken links:
-[Graph.subgraph]: api/scala/index.html#org.apache.spark.graphx.Graph@subgraph((EdgeTriplet[VD,ED])⇒Boolean,(VertexId,VD)⇒Boolean):Graph[VD,ED]
{% highlight scala %}
// Create an RDD for the vertices
@@ -469,13 +450,12 @@ validGraph.triplets.map(
> Note in the above example only the vertex predicate is provided. The `subgraph` operator defaults
> to `true` if the vertex or edge predicates are not provided.
-The [`mask`][Graph.mask] operator also constructs a subgraph by returning a graph that contains the
+The [`mask`][Graph.mask] operator constructs a subgraph by returning a graph that contains the
vertices and edges that are also found in the input graph. This can be used in conjunction with the
`subgraph` operator to restrict a graph based on the properties in another related graph. For
example, we might run connected components using the graph with missing vertices and then restrict
the answer to the valid subgraph.
-[Graph.mask]: api/scala/index.html#org.apache.spark.graphx.Graph@mask[VD2,ED2](Graph[VD2,ED2])(ClassTag[VD2],ClassTag[ED2]):Graph[VD,ED]
{% highlight scala %}
// Run Connected Components
@@ -490,10 +470,9 @@ The [`groupEdges`][Graph.groupEdges] operator merges parallel edges (i.e., dupli
pairs of vertices) in the multigraph. In many numerical applications, parallel edges can be *added*
(their weights combined) into a single edge thereby reducing the size of the graph.
-[Graph.groupEdges]: api/scala/index.html#org.apache.spark.graphx.Graph@groupEdges((ED,ED)⇒ED):Graph[VD,ED]
+
## Join Operators
-
In many cases it is necessary to join data from external collections (RDDs) with graphs. For
example, we might have extra user properties that we want to merge with an existing graph or we
@@ -514,10 +493,8 @@ returns a new graph with the vertex properties obtained by applying the user def
to the result of the joined vertices. Vertices without a matching value in the RDD retain their
original value.
-[GraphOps.joinVertices]: api/scala/index.html#org.apache.spark.graphx.GraphOps@joinVertices[U](RDD[(VertexId,U)])((VertexId,VD,U)⇒VD)(ClassTag[U]):Graph[VD,ED]
-
-> Note that if the RDD contains more than one value for a given vertex only one will be used. It
-> is therefore recommended that the input RDD be first made unique using the following which will
+> Note that if the RDD contains more than one value for a given vertex only one will be used. It
+> is therefore recommended that the input RDD be made unique using the following which will
> also *pre-index* the resulting values to substantially accelerate the subsequent join.
> {% highlight scala %}
val nonUniqueCosts: RDD[(VertexID, Double)]
@@ -533,8 +510,6 @@ property type. Because not all vertices may have a matching value in the input
function takes an `Option` type. For example, we can setup a graph for PageRank by initializing
vertex properties with their `outDegree`.
-[Graph.outerJoinVertices]: api/scala/index.html#org.apache.spark.graphx.Graph@outerJoinVertices[U,VD2](RDD[(VertexId,U)])((VertexId,VD,Option[U])⇒VD2)(ClassTag[U],ClassTag[VD2]):Graph[VD2,ED]
-
{% highlight scala %}
val outDegrees: VertexRDD[Int] = graph.outDegrees
@@ -555,65 +530,76 @@ val joinedGraph = graph.joinVertices(uniqueCosts,
(id: VertexID, oldCost: Double, extraCost: Double) => oldCost + extraCost)
{% endhighlight %}
+>
+
+
## Neighborhood Aggregation
-A key part of graph computation is aggregating information about the neighborhood of each vertex.
-For example we might want to know the number of followers each user has or the average age of the
+A key step in may graph analytics tasks is aggregating information about the neighborhood of each
+vertex.
+For example, we might want to know the number of followers each user has or the average age of the
the followers of each user. Many iterative graph algorithms (e.g., PageRank, Shortest Path, and
connected components) repeatedly aggregate properties of neighboring vertices (e.g., current
PageRank Value, shortest path to the source, and smallest reachable vertex id).
-### Map Reduce Triplets (mapReduceTriplets)
-
+> To improve performance the primary aggregation operator changed from
+`graph.mapReduceTriplets` to the new `graph.AggregateMessages`. While the changes in the API are
+relatively small, we provide a transition guide below.
-[Graph.mapReduceTriplets]: api/scala/index.html#org.apache.spark.graphx.Graph@mapReduceTriplets[A](mapFunc:org.apache.spark.graphx.EdgeTriplet[VD,ED]=>Iterator[(org.apache.spark.graphx.VertexId,A)],reduceFunc:(A,A)=>A,activeSetOpt:Option[(org.apache.spark.graphx.VertexRDD[_],org.apache.spark.graphx.EdgeDirection)])(implicitevidence$10:scala.reflect.ClassTag[A]):org.apache.spark.graphx.VertexRDD[A]
+
-The core (heavily optimized) aggregation primitive in GraphX is the
-[`mapReduceTriplets`][Graph.mapReduceTriplets] operator:
+### Aggregate Messages (aggregateMessages)
+
+The core aggregation operation in GraphX is [`aggregateMessages`][Graph.aggregateMessages].
+This operator applies a user defined `sendMsg` function to each edge triplet in the graph
+and then uses the `mergeMsg` function to aggregate those messages at their destination vertex.
{% highlight scala %}
class Graph[VD, ED] {
- def mapReduceTriplets[A](
- map: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
- reduce: (A, A) => A)
- : VertexRDD[A]
+ def aggregateMessages[Msg: ClassTag](
+ sendMsg: EdgeContext[VD, ED, Msg] => Unit,
+ mergeMsg: (Msg, Msg) => Msg,
+ tripletFields: TripletFields = TripletFields.All)
+ : VertexRDD[Msg]
}
{% endhighlight %}
-The [`mapReduceTriplets`][Graph.mapReduceTriplets] operator takes a user defined map function which
-is applied to each triplet and can yield *messages* destined to either (none or both) vertices in
-the triplet. To facilitate optimized pre-aggregation, we currently only support messages destined
-to the source or destination vertex of the triplet. The user defined `reduce` function combines the
-messages destined to each vertex. The `mapReduceTriplets` operator returns a `VertexRDD[A]`
-containing the aggregate message (of type `A`) destined to each vertex. Vertices that do not
+The user defined `sendMsg` function takes an [`EdgeContext`][EdgeContext], which exposes the
+source and destination attributes along with the edge attribute and functions
+([`sendToSrc`][EdgeContext.sendToSrc], and [`sendToDst`][EdgeContext.sendToDst]) to send
+messages to the source and destination attributes. Think of `sendMsg` as the map
+function in map-reduce.
+The user defined `mergeMsg` function takes two messages destined to the same vertex and
+yields a single message. Think of `mergeMsg` as the reduce function in map-reduce.
+The [`aggregateMessages`][Graph.aggregateMessages] operator returns a `VertexRDD[Msg]`
+containing the aggregate message (of type `Msg`) destined to each vertex. Vertices that did not
receive a message are not included in the returned `VertexRDD`.
-
-
-
Note that mapReduceTriplets takes an additional optional activeSet
-(not shown above see API docs for details) which restricts the map phase to edges adjacent to the
-vertices in the provided VertexRDD:
The EdgeDirection specifies which edges adjacent to the vertex set are included in the map
-phase. If the direction is In, then the user defined map function will
-only be run only on edges with the destination vertex in the active set. If the direction is
-Out, then the map function will only be run only on edges originating from
-vertices in the active set. If the direction is Either, then the map
-function will be run only on edges with either vertex in the active set. If the direction is
-Both, then the map function will be run only on edges with both vertices
-in the active set. The active set must be derived from the set of vertices in the graph.
-Restricting computation to triplets adjacent to a subset of the vertices is often necessary in
-incremental iterative computation and is a key part of the GraphX implementation of Pregel.
-
-
-
-In the following example we use the `mapReduceTriplets` operator to compute the average age of the
-more senior followers of each user.
+
+
+In addition, [`aggregateMessages`][Graph.aggregateMessages] takes an optional
+`tripletsFields` which indicates what data is accessed in the [`EdgeContext`][EdgeContext]
+(i.e., the source vertex attribute but not the destination vertex attribute).
+The possible options for the `tripletsFields` are defined in [`TripletFields`][TripletFields] and
+the default value is [`TripletFields.All`][TripletFields.All] which indicates that the user
+defined `sendMsg` function may access any of the fields in the [`EdgeContext`][EdgeContext].
+The `tripletFields` argument can be used to notify GraphX that only part of the
+[`EdgeContext`][EdgeContext] will be needed allowing GraphX to select an optimized join strategy.
+For example if we are computing the average age of the followers of each user we would only require
+the source field and so we would use [`TripletFields.Src`][TripletFields.Src] to indicate that we
+only require the source field
+
+> In earlier versions of GraphX we used byte code inspection to infer the
+[`TripletFields`][TripletFields] however we have found that bytecode inspection to be
+slightly unreliable and instead opted for more explicit user control.
+
+In the following example we use the [`aggregateMessages`][Graph.aggregateMessages] operator to
+compute the average age of the more senior followers of each user.
{% highlight scala %}
// Import random graph generation library
@@ -622,14 +608,11 @@ import org.apache.spark.graphx.util.GraphGenerators
val graph: Graph[Double, Int] =
GraphGenerators.logNormalGraph(sc, numVertices = 100).mapVertices( (id, _) => id.toDouble )
// Compute the number of older followers and their total age
-val olderFollowers: VertexRDD[(Int, Double)] = graph.mapReduceTriplets[(Int, Double)](
+val olderFollowers: VertexRDD[(Int, Double)] = graph.aggregateMessages[(Int, Double)](
triplet => { // Map Function
if (triplet.srcAttr > triplet.dstAttr) {
// Send message to destination vertex containing counter and age
- Iterator((triplet.dstId, (1, triplet.srcAttr)))
- } else {
- // Don't send a message for this triplet
- Iterator.empty
+ triplet.sendToDst(1, triplet.srcAttr)
}
},
// Add counter and age
@@ -642,10 +625,57 @@ val avgAgeOfOlderFollowers: VertexRDD[Double] =
avgAgeOfOlderFollowers.collect.foreach(println(_))
{% endhighlight %}
-> Note that the `mapReduceTriplets` operation performs optimally when the messages (and the sums of
-> messages) are constant sized (e.g., floats and addition instead of lists and concatenation). More
-> precisely, the result of `mapReduceTriplets` should ideally be sub-linear in the degree of each
-> vertex.
+> The `aggregateMessages` operation performs optimally when the messages (and the sums of
+> messages) are constant sized (e.g., floats and addition instead of lists and concatenation).
+
+
+
+### Map Reduce Triplets Transition Guide (Legacy)
+
+In earlier versions of GraphX we neighborhood aggregation was accomplished using the
+[`mapReduceTriplets`][Graph.mapReduceTriplets] operator:
+
+{% highlight scala %}
+class Graph[VD, ED] {
+ def mapReduceTriplets[Msg](
+ map: EdgeTriplet[VD, ED] => Iterator[(VertexId, Msg)],
+ reduce: (Msg, Msg) => Msg)
+ : VertexRDD[Msg]
+}
+{% endhighlight %}
+
+The [`mapReduceTriplets`][Graph.mapReduceTriplets] operator takes a user defined map function which
+is applied to each triplet and can yield *messages* which are aggregated using the user defined
+`reduce` function.
+However, we found the user of the returned iterator to be expensive and it inhibited our ability to
+apply additional optimizations (e.g., local vertex renumbering).
+In [`aggregateMessages`][Graph.aggregateMessages] we introduced the EdgeContext which exposes the
+triplet fields and also functions to explicitly send messages to the source and destination vertex.
+Furthermore we removed bytecode inspection and instead require the user to indicate what fields
+in the triplet are actually required.
+
+The following code block using `mapReduceTriplets`:
+
+{% highlight scala %}
+val graph: Graph[Int, Float] = ...
+def msgFun(triplet: Triplet[Int, Float]): Iterator[(Int, String)] = {
+ Iterator((triplet.dstId, "Hi"))
+}
+def reduceFun(a: Int, b: Int): Int = a + b
+val result = graph.mapReduceTriplets[String](msgFun, reduceFun)
+{% endhighlight %}
+
+can be rewritten using `aggregateMessages` as:
+
+{% highlight scala %}
+val graph: Graph[Int, Float] = ...
+def msgFun(triplet: EdgeContext[Int, Float, String]) {
+ triplet.sendToDst("Hi")
+}
+def reduceFun(a: Int, b: Int): Int = a + b
+val result = graph.aggregateMessages[String](msgFun, reduceFun)
+{% endhighlight %}
+
### Computing Degree Information
@@ -673,10 +703,6 @@ attributes at each vertex. This can be easily accomplished using the
[`collectNeighborIds`][GraphOps.collectNeighborIds] and the
[`collectNeighbors`][GraphOps.collectNeighbors] operators.
-[GraphOps.collectNeighborIds]: api/scala/index.html#org.apache.spark.graphx.GraphOps@collectNeighborIds(EdgeDirection):VertexRDD[Array[VertexId]]
-[GraphOps.collectNeighbors]: api/scala/index.html#org.apache.spark.graphx.GraphOps@collectNeighbors(EdgeDirection):VertexRDD[Array[(VertexId,VD)]]
-
-
{% highlight scala %}
class GraphOps[VD, ED] {
def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexId]]
@@ -684,36 +710,34 @@ class GraphOps[VD, ED] {
}
{% endhighlight %}
-> Note that these operators can be quite costly as they duplicate information and require
+> These operators can be quite costly as they duplicate information and require
> substantial communication. If possible try expressing the same computation using the
-> `mapReduceTriplets` operator directly.
+> [`aggregateMessages`][Graph.aggregateMessages] operator directly.
## Caching and Uncaching
In Spark, RDDs are not persisted in memory by default. To avoid recomputation, they must be explicitly cached when using them multiple times (see the [Spark Programming Guide][RDD Persistence]). Graphs in GraphX behave the same way. **When using a graph multiple times, make sure to call [`Graph.cache()`][Graph.cache] on it first.**
-[RDD Persistence]: programming-guide.html#rdd-persistence
-[Graph.cache]: api/scala/index.html#org.apache.spark.graphx.Graph@cache():Graph[VD,ED]
In iterative computations, *uncaching* may also be necessary for best performance. By default, cached RDDs and graphs will remain in memory until memory pressure forces them to be evicted in LRU order. For iterative computation, intermediate results from previous iterations will fill up the cache. Though they will eventually be evicted, the unnecessary data stored in memory will slow down garbage collection. It would be more efficient to uncache intermediate results as soon as they are no longer necessary. This involves materializing (caching and forcing) a graph or RDD every iteration, uncaching all other datasets, and only using the materialized dataset in future iterations. However, because graphs are composed of multiple RDDs, it can be difficult to unpersist them correctly. **For iterative computation we recommend using the Pregel API, which correctly unpersists intermediate results.**
-# Pregel API
-Graphs are inherently recursive data-structures as properties of vertices depend on properties of
+# Pregel API
+
+Graphs are inherently recursive data structures as properties of vertices depend on properties of
their neighbors which in turn depend on properties of *their* neighbors. As a
consequence many important graph algorithms iteratively recompute the properties of each vertex
until a fixed-point condition is reached. A range of graph-parallel abstractions have been proposed
-to express these iterative algorithms. GraphX exposes a Pregel-like operator which is a fusion of
-the widely used Pregel and GraphLab abstractions.
+to express these iterative algorithms. GraphX exposes a variant of the Pregel API.
-At a high-level the Pregel operator in GraphX is a bulk-synchronous parallel messaging abstraction
-*constrained to the topology of the graph*. The Pregel operator executes in a series of super-steps
-in which vertices receive the *sum* of their inbound messages from the previous super- step, compute
+At a high level the Pregel operator in GraphX is a bulk-synchronous parallel messaging abstraction
+*constrained to the topology of the graph*. The Pregel operator executes in a series of super steps
+in which vertices receive the *sum* of their inbound messages from the previous super step, compute
a new value for the vertex property, and then send messages to neighboring vertices in the next
-super-step. Unlike Pregel and instead more like GraphLab messages are computed in parallel as a
+super step. Unlike Pregel, messages are computed in parallel as a
function of the edge triplet and the message computation has access to both the source and
-destination vertex attributes. Vertices that do not receive a message are skipped within a super-
+destination vertex attributes. Vertices that do not receive a message are skipped within a super
step. The Pregel operators terminates iteration and returns the final graph when there are no
messages remaining.
@@ -724,8 +748,6 @@ messages remaining.
The following is the type signature of the [Pregel operator][GraphOps.pregel] as well as a *sketch*
of its implementation (note calls to graph.cache have been removed):
-[GraphOps.pregel]: api/scala/index.html#org.apache.spark.graphx.GraphOps@pregel[A](A,Int,EdgeDirection)((VertexId,VD,A)⇒VD,(EdgeTriplet[VD,ED])⇒Iterator[(VertexId,A)],(A,A)⇒A)(ClassTag[A]):Graph[VD,ED]
-
{% highlight scala %}
class GraphOps[VD, ED] {
def pregel[A]
@@ -795,9 +817,10 @@ val sssp = initialGraph.pregel(Double.PositiveInfinity)(
println(sssp.vertices.collect.mkString("\n"))
{% endhighlight %}
-# Graph Builders
+# Graph Builders
+
GraphX provides several ways of building a graph from a collection of vertices and edges in an RDD or on disk. None of the graph builders repartitions the graph's edges by default; instead, edges are left in their default partitions (such as their original blocks in HDFS). [`Graph.groupEdges`][Graph.groupEdges] requires the graph to be repartitioned because it assumes identical edges will be colocated on the same partition, so you must call [`Graph.partitionBy`][Graph.partitionBy] before calling `groupEdges`.
{% highlight scala %}
@@ -848,18 +871,12 @@ object Graph {
[`Graph.fromEdgeTuples`][Graph.fromEdgeTuples] allows creating a graph from only an RDD of edge tuples, assigning the edges the value 1, and automatically creating any vertices mentioned by edges and assigning them the default value. It also supports deduplicating the edges; to deduplicate, pass `Some` of a [`PartitionStrategy`][PartitionStrategy] as the `uniqueEdges` parameter (for example, `uniqueEdges = Some(PartitionStrategy.RandomVertexCut)`). A partition strategy is necessary to colocate identical edges on the same partition so they can be deduplicated.
-[PartitionStrategy]: api/scala/index.html#org.apache.spark.graphx.PartitionStrategy$
-
-[GraphLoader.edgeListFile]: api/scala/index.html#org.apache.spark.graphx.GraphLoader$@edgeListFile(SparkContext,String,Boolean,Int):Graph[Int,Int]
-[Graph.apply]: api/scala/index.html#org.apache.spark.graphx.Graph$@apply[VD,ED](RDD[(VertexId,VD)],RDD[Edge[ED]],VD)(ClassTag[VD],ClassTag[ED]):Graph[VD,ED]
-[Graph.fromEdgeTuples]: api/scala/index.html#org.apache.spark.graphx.Graph$@fromEdgeTuples[VD](RDD[(VertexId,VertexId)],VD,Option[PartitionStrategy])(ClassTag[VD]):Graph[VD,Int]
-[Graph.fromEdges]: api/scala/index.html#org.apache.spark.graphx.Graph$@fromEdges[VD,ED](RDD[Edge[ED]],VD)(ClassTag[VD],ClassTag[ED]):Graph[VD,ED]
+
# Vertex and Edge RDDs
-
GraphX exposes `RDD` views of the vertices and edges stored within the graph. However, because
-GraphX maintains the vertices and edges in optimized data-structures and these data-structures
+GraphX maintains the vertices and edges in optimized data structures and these data structures
provide additional functionality, the vertices and edges are returned as `VertexRDD` and `EdgeRDD`
respectively. In this section we review some of the additional useful functionality in these types.
@@ -870,7 +887,7 @@ The `VertexRDD[A]` extends `RDD[(VertexID, A)]` and adds the additional constrai
attribute of type `A`. Internally, this is achieved by storing the vertex attributes in a reusable
hash-map data-structure. As a consequence if two `VertexRDD`s are derived from the same base
`VertexRDD` (e.g., by `filter` or `mapValues`) they can be joined in constant time without hash
-evaluations. To leverage this indexed data-structure, the `VertexRDD` exposes the following
+evaluations. To leverage this indexed data structure, the `VertexRDD` exposes the following
additional functionality:
{% highlight scala %}
@@ -893,7 +910,7 @@ class VertexRDD[VD] extends RDD[(VertexID, VD)] {
Notice, for example, how the `filter` operator returns an `VertexRDD`. Filter is actually
implemented using a `BitSet` thereby reusing the index and preserving the ability to do fast joins
with other `VertexRDD`s. Likewise, the `mapValues` operators do not allow the `map` function to
-change the `VertexID` thereby enabling the same `HashMap` data-structures to be reused. Both the
+change the `VertexID` thereby enabling the same `HashMap` data structures to be reused. Both the
`leftJoin` and `innerJoin` are able to identify when joining two `VertexRDD`s derived from the same
`HashMap` and implement the join by linear scan rather than costly point lookups.
@@ -916,21 +933,19 @@ val setC: VertexRDD[Double] = setA.innerJoin(setB)((id, a, b) => a + b)
## EdgeRDDs
-The `EdgeRDD[ED, VD]`, which extends `RDD[Edge[ED]]` organizes the edges in blocks partitioned using one
+The `EdgeRDD[ED]`, which extends `RDD[Edge[ED]]` organizes the edges in blocks partitioned using one
of the various partitioning strategies defined in [`PartitionStrategy`][PartitionStrategy]. Within
each partition, edge attributes and adjacency structure, are stored separately enabling maximum
reuse when changing attribute values.
-[PartitionStrategy]: api/scala/index.html#org.apache.spark.graphx.PartitionStrategy
-
The three additional functions exposed by the `EdgeRDD` are:
{% highlight scala %}
// Transform the edge attributes while preserving the structure
-def mapValues[ED2](f: Edge[ED] => ED2): EdgeRDD[ED2, VD]
+def mapValues[ED2](f: Edge[ED] => ED2): EdgeRDD[ED2]
// Revere the edges reusing both attributes and structure
-def reverse: EdgeRDD[ED, VD]
+def reverse: EdgeRDD[ED]
// Join two `EdgeRDD`s partitioned using the same partitioning strategy.
-def innerJoin[ED2, ED3](other: EdgeRDD[ED2, VD])(f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3, VD]
+def innerJoin[ED2, ED3](other: EdgeRDD[ED2])(f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3]
{% endhighlight %}
In most applications we have found that operations on the `EdgeRDD` are accomplished through the
@@ -960,7 +975,6 @@ the [`Graph.partitionBy`][Graph.partitionBy] operator. The default partitioning
the initial partitioning of the edges as provided on graph construction. However, users can easily
switch to 2D-partitioning or other heuristics included in GraphX.
-[Graph.partitionBy]: api/scala/index.html#org.apache.spark.graphx.Graph$@partitionBy(partitionStrategy:org.apache.spark.graphx.PartitionStrategy):org.apache.spark.graphx.Graph[VD,ED]
+# Graph Algorithms
+
GraphX includes a set of graph algorithms to simplify analytics tasks. The algorithms are contained in the `org.apache.spark.graphx.lib` package and can be accessed directly as methods on `Graph` via [`GraphOps`][GraphOps]. This section describes the algorithms and how they are used.
-## PageRank
+## PageRank
+
PageRank measures the importance of each vertex in a graph, assuming an edge from *u* to *v* represents an endorsement of *v*'s importance by *u*. For example, if a Twitter user is followed by many others, the user will be ranked highly.
GraphX comes with static and dynamic implementations of PageRank as methods on the [`PageRank` object][PageRank]. Static PageRank runs for a fixed number of iterations, while dynamic PageRank runs until the ranks converge (i.e., stop changing by more than a specified tolerance). [`GraphOps`][GraphOps] allows calling these algorithms directly as methods on `Graph`.
GraphX also includes an example social network dataset that we can run PageRank on. A set of users is given in `graphx/data/users.txt`, and a set of relationships between users is given in `graphx/data/followers.txt`. We compute the PageRank of each user as follows:
-[PageRank]: api/scala/index.html#org.apache.spark.graphx.lib.PageRank$
-
{% highlight scala %}
// Load the edges as a graph
val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt")
@@ -1014,8 +1028,6 @@ println(ranksByUsername.collect().mkString("\n"))
The connected components algorithm labels each connected component of the graph with the ID of its lowest-numbered vertex. For example, in a social network, connected components can approximate clusters. GraphX contains an implementation of the algorithm in the [`ConnectedComponents` object][ConnectedComponents], and we compute the connected components of the example social network dataset from the [PageRank section](#pagerank) as follows:
-[ConnectedComponents]: api/scala/index.html#org.apache.spark.graphx.lib.ConnectedComponents$
-
{% highlight scala %}
// Load the graph as in the PageRank example
val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt")
@@ -1037,9 +1049,6 @@ println(ccByUsername.collect().mkString("\n"))
A vertex is part of a triangle when it has two adjacent vertices with an edge between them. GraphX implements a triangle counting algorithm in the [`TriangleCount` object][TriangleCount] that determines the number of triangles passing through each vertex, providing a measure of clustering. We compute the triangle count of the social network dataset from the [PageRank section](#pagerank). *Note that `TriangleCount` requires the edges to be in canonical orientation (`srcId < dstId`) and the graph to be partitioned using [`Graph.partitionBy`][Graph.partitionBy].*
-[TriangleCount]: api/scala/index.html#org.apache.spark.graphx.lib.TriangleCount$
-[Graph.partitionBy]: api/scala/index.html#org.apache.spark.graphx.Graph@partitionBy(PartitionStrategy):Graph[VD,ED]
-
{% highlight scala %}
// Load the edges in canonical order and partition the graph for triangle count
val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt", true).partitionBy(PartitionStrategy.RandomVertexCut)
diff --git a/docs/img/data_parallel_vs_graph_parallel.png b/docs/img/data_parallel_vs_graph_parallel.png
deleted file mode 100644
index d3918f01d8f3b..0000000000000
Binary files a/docs/img/data_parallel_vs_graph_parallel.png and /dev/null differ
diff --git a/docs/img/graph_analytics_pipeline.png b/docs/img/graph_analytics_pipeline.png
deleted file mode 100644
index 6d606e01894ae..0000000000000
Binary files a/docs/img/graph_analytics_pipeline.png and /dev/null differ
diff --git a/docs/img/tables_and_graphs.png b/docs/img/tables_and_graphs.png
deleted file mode 100644
index ec37bb45a62f0..0000000000000
Binary files a/docs/img/tables_and_graphs.png and /dev/null differ
diff --git a/docs/index.md b/docs/index.md
index edd622ec90f64..171d6ddad62f3 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -112,6 +112,7 @@ options for deployment:
**External Resources:**
* [Spark Homepage](http://spark.apache.org)
+* [Spark Wiki](https://cwiki.apache.org/confluence/display/SPARK)
* [Mailing Lists](http://spark.apache.org/mailing-lists.html): ask questions about Spark here
* [AMP Camps](http://ampcamp.berkeley.edu/): a series of training camps at UC Berkeley that featured talks and
exercises about Spark, Spark Streaming, Mesos, and more. [Videos](http://ampcamp.berkeley.edu/3/),
diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index d10bd63746629..c696ae9c8e8c8 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -34,7 +34,7 @@ a given dataset, the algorithm returns the best clustering result).
* *initializationSteps* determines the number of steps in the k-means\|\| algorithm.
* *epsilon* determines the distance threshold within which we consider k-means to have converged.
-## Examples
+### Examples
@@ -69,7 +69,7 @@ println("Within Set Sum of Squared Errors = " + WSSSE)
All of MLlib's methods use Java-friendly types, so you can import and call them there the same
way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the
Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by
-calling `.rdd()` on your `JavaRDD` object. A standalone application example
+calling `.rdd()` on your `JavaRDD` object. A self-contained application example
that is equivalent to the provided example in Scala is given below:
{% highlight java %}
@@ -113,12 +113,6 @@ public class KMeansExample {
}
}
{% endhighlight %}
-
-In order to run the above standalone application, follow the instructions
-provided in the [Standalone
-Applications](quick-start.html#standalone-applications) section of the Spark
-quick-start guide. Be sure to also include *spark-mllib* to your build file as
-a dependency.
@@ -153,3 +147,103 @@ print("Within Set Sum of Squared Error = " + str(WSSSE))
+
+In order to run the above application, follow the instructions
+provided in the [Self-Contained Applications](quick-start.html#self-contained-applications)
+section of the Spark
+Quick Start guide. Be sure to also include *spark-mllib* to your build file as
+a dependency.
+
+## Streaming clustering
+
+When data arrive in a stream, we may want to estimate clusters dynamically,
+updating them as new data arrive. MLlib provides support for streaming k-means clustering,
+with parameters to control the decay (or "forgetfulness") of the estimates. The algorithm
+uses a generalization of the mini-batch k-means update rule. For each batch of data, we assign
+all points to their nearest cluster, compute new cluster centers, then update each cluster using:
+
+`\begin{equation}
+ c_{t+1} = \frac{c_tn_t\alpha + x_tm_t}{n_t\alpha+m_t}
+\end{equation}`
+`\begin{equation}
+ n_{t+1} = n_t + m_t
+\end{equation}`
+
+Where `$c_t$` is the previous center for the cluster, `$n_t$` is the number of points assigned
+to the cluster thus far, `$x_t$` is the new cluster center from the current batch, and `$m_t$`
+is the number of points added to the cluster in the current batch. The decay factor `$\alpha$`
+can be used to ignore the past: with `$\alpha$=1` all data will be used from the beginning;
+with `$\alpha$=0` only the most recent data will be used. This is analogous to an
+exponentially-weighted moving average.
+
+The decay can be specified using a `halfLife` parameter, which determines the
+correct decay factor `a` such that, for data acquired
+at time `t`, its contribution by time `t + halfLife` will have dropped to 0.5.
+The unit of time can be specified either as `batches` or `points` and the update rule
+will be adjusted accordingly.
+
+### Examples
+
+This example shows how to estimate clusters on streaming data.
+
+
+
+
+
+First we import the neccessary classes.
+
+{% highlight scala %}
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.clustering.StreamingKMeans
+
+{% endhighlight %}
+
+Then we make an input stream of vectors for training, as well as a stream of labeled data
+points for testing. We assume a StreamingContext `ssc` has been created, see
+[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info.
+
+{% highlight scala %}
+
+val trainingData = ssc.textFileStream("/training/data/dir").map(Vectors.parse)
+val testData = ssc.textFileStream("/testing/data/dir").map(LabeledPoint.parse)
+
+{% endhighlight %}
+
+We create a model with random clusters and specify the number of clusters to find
+
+{% highlight scala %}
+
+val numDimensions = 3
+val numClusters = 2
+val model = new StreamingKMeans()
+ .setK(numClusters)
+ .setDecayFactor(1.0)
+ .setRandomCenters(numDimensions, 0.0)
+
+{% endhighlight %}
+
+Now register the streams for training and testing and start the job, printing
+the predicted cluster assignments on new data points as they arrive.
+
+{% highlight scala %}
+
+model.trainOn(trainingData)
+model.predictOnValues(testData).print()
+
+ssc.start()
+ssc.awaitTermination()
+
+{% endhighlight %}
+
+As you add new text files with data the cluster centers will update. Each training
+point should be formatted as `[x1, x2, x3]`, and each test data point
+should be formatted as `(y, [x1, x2, x3])`, where `y` is some useful label or identifier
+(e.g. a true category assignment). Anytime a text file is placed in `/training/data/dir`
+the model will update. Anytime a text file is placed in `/testing/data/dir`
+you will see predictions. With new data, the cluster centers will change!
+
+
+
+
diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md
index d5c539db791be..2094963392295 100644
--- a/docs/mllib-collaborative-filtering.md
+++ b/docs/mllib-collaborative-filtering.md
@@ -110,7 +110,7 @@ val model = ALS.trainImplicit(ratings, rank, numIterations, alpha)
All of MLlib's methods use Java-friendly types, so you can import and call them there the same
way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the
Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by
-calling `.rdd()` on your `JavaRDD` object. A standalone application example
+calling `.rdd()` on your `JavaRDD` object. A self-contained application example
that is equivalent to the provided example in Scala is given bellow:
{% highlight java %}
@@ -184,12 +184,6 @@ public class CollaborativeFiltering {
}
}
{% endhighlight %}
-
-In order to run the above standalone application, follow the instructions
-provided in the [Standalone
-Applications](quick-start.html#standalone-applications) section of the Spark
-quick-start guide. Be sure to also include *spark-mllib* to your build file as
-a dependency.
+In order to run the above application, follow the instructions
+provided in the [Self-Contained Applications](quick-start.html#self-contained-applications)
+section of the Spark
+Quick Start guide. Be sure to also include *spark-mllib* to your build file as
+a dependency.
+
## Tutorial
The [training exercises](https://databricks-training.s3.amazonaws.com/index.html) from the Spark Summit 2014 include a hands-on tutorial for
diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md
index 21cb35b4270ca..870fed6cc5024 100644
--- a/docs/mllib-dimensionality-reduction.md
+++ b/docs/mllib-dimensionality-reduction.md
@@ -121,9 +121,9 @@ public class SVD {
The same code applies to `IndexedRowMatrix` if `U` is defined as an
`IndexedRowMatrix`.
-In order to run the above standalone application, follow the instructions
-provided in the [Standalone
-Applications](quick-start.html#standalone-applications) section of the Spark
+In order to run the above application, follow the instructions
+provided in the [Self-Contained
+Applications](quick-start.html#self-contained-applications) section of the Spark
quick-start guide. Be sure to also include *spark-mllib* to your build file as
a dependency.
@@ -200,10 +200,11 @@ public class PCA {
}
{% endhighlight %}
-In order to run the above standalone application, follow the instructions
-provided in the [Standalone
-Applications](quick-start.html#standalone-applications) section of the Spark
-quick-start guide. Be sure to also include *spark-mllib* to your build file as
-a dependency.
+
+In order to run the above application, follow the instructions
+provided in the [Self-Contained Applications](quick-start.html#self-contained-applications)
+section of the Spark
+quick-start guide. Be sure to also include *spark-mllib* to your build file as
+a dependency.
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md
index 1511ae6dda4ed..197bc77d506c6 100644
--- a/docs/mllib-feature-extraction.md
+++ b/docs/mllib-feature-extraction.md
@@ -83,7 +83,7 @@ val idf = new IDF().fit(tf)
val tfidf: RDD[Vector] = idf.transform(tf)
{% endhighlight %}
-MLLib's IDF implementation provides an option for ignoring terms which occur in less than a
+MLlib's IDF implementation provides an option for ignoring terms which occur in less than a
minimum number of documents. In such cases, the IDF for these terms is set to 0. This feature
can be used by passing the `minDocFreq` value to the IDF constructor.
@@ -95,8 +95,49 @@ tf.cache()
val idf = new IDF(minDocFreq = 2).fit(tf)
val tfidf: RDD[Vector] = idf.transform(tf)
{% endhighlight %}
+
+
+
+TF and IDF are implemented in [HashingTF](api/python/pyspark.mllib.html#pyspark.mllib.feature.HashingTF)
+and [IDF](api/python/pyspark.mllib.html#pyspark.mllib.feature.IDF).
+`HashingTF` takes an RDD of list as the input.
+Each record could be an iterable of strings or other types.
+
+{% highlight python %}
+from pyspark import SparkContext
+from pyspark.mllib.feature import HashingTF
+
+sc = SparkContext()
+# Load documents (one per line).
+documents = sc.textFile("...").map(lambda line: line.split(" "))
+
+hashingTF = HashingTF()
+tf = hashingTF.transform(documents)
+{% endhighlight %}
+
+While applying `HashingTF` only needs a single pass to the data, applying `IDF` needs two passes:
+first to compute the IDF vector and second to scale the term frequencies by IDF.
+{% highlight python %}
+from pyspark.mllib.feature import IDF
+
+# ... continue from the previous example
+tf.cache()
+idf = IDF().fit(tf)
+tfidf = idf.transform(tf)
+{% endhighlight %}
+
+MLLib's IDF implementation provides an option for ignoring terms which occur in less than a
+minimum number of documents. In such cases, the IDF for these terms is set to 0. This feature
+can be used by passing the `minDocFreq` value to the IDF constructor.
+
+{% highlight python %}
+# ... continue from the previous example
+tf.cache()
+idf = IDF(minDocFreq=2).fit(tf)
+tfidf = idf.transform(tf)
+{% endhighlight %}
+{% highlight python %}
+from pyspark.mllib.util import MLUtils
+from pyspark.mllib.linalg import Vectors
+from pyspark.mllib.feature import Normalizer
+
+data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
+labels = data.map(lambda x: x.label)
+features = data.map(lambda x: x.features)
+
+normalizer1 = Normalizer()
+normalizer2 = Normalizer(p=float("inf"))
+
+# Each sample in data1 will be normalized using $L^2$ norm.
+data1 = labels.zip(normalizer1.transform(features))
+
+# Each sample in data2 will be normalized using $L^\infty$ norm.
+data2 = labels.zip(normalizer2.transform(features))
+{% endhighlight %}
+
diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md
index d31bec3e1bd01..bc914a1899801 100644
--- a/docs/mllib-linear-methods.md
+++ b/docs/mllib-linear-methods.md
@@ -247,7 +247,7 @@ val modelL1 = svmAlg.run(training)
All of MLlib's methods use Java-friendly types, so you can import and call them there the same
way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the
Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by
-calling `.rdd()` on your `JavaRDD` object. A standalone application example
+calling `.rdd()` on your `JavaRDD` object. A self-contained application example
that is equivalent to the provided example in Scala is given bellow:
{% highlight java %}
@@ -323,9 +323,9 @@ svmAlg.optimizer()
final SVMModel modelL1 = svmAlg.run(training.rdd());
{% endhighlight %}
-In order to run the above standalone application, follow the instructions
-provided in the [Standalone
-Applications](quick-start.html#standalone-applications) section of the Spark
+In order to run the above application, follow the instructions
+provided in the [Self-Contained
+Applications](quick-start.html#self-contained-applications) section of the Spark
quick-start guide. Be sure to also include *spark-mllib* to your build file as
a dependency.
@@ -482,12 +482,6 @@ public class LinearRegression {
}
}
{% endhighlight %}
-
-In order to run the above standalone application, follow the instructions
-provided in the [Standalone
-Applications](quick-start.html#standalone-applications) section of the Spark
-quick-start guide. Be sure to also include *spark-mllib* to your build file as
-a dependency.
+In order to run the above application, follow the instructions
+provided in the [Self-Contained Applications](quick-start.html#self-contained-applications)
+section of the Spark
+quick-start guide. Be sure to also include *spark-mllib* to your build file as
+a dependency.
+
## Streaming linear regression
When data arrive in a streaming fashion, it is useful to fit regression models online,
diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md
index 7f9d4c6563944..d5b044d94fdd7 100644
--- a/docs/mllib-naive-bayes.md
+++ b/docs/mllib-naive-bayes.md
@@ -88,11 +88,11 @@ JavaPairRDD predictionAndLabel =
return new Tuple2(model.predict(p.features()), p.label());
}
});
-double accuracy = 1.0 * predictionAndLabel.filter(new Function, Boolean>() {
+double accuracy = predictionAndLabel.filter(new Function, Boolean>() {
@Override public Boolean call(Tuple2 pl) {
- return pl._1() == pl._2();
+ return pl._1().equals(pl._2());
}
- }).count() / test.count();
+ }).count() / (double) test.count();
{% endhighlight %}
diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md
index c4632413991f1..ca8c29218f52d 100644
--- a/docs/mllib-statistics.md
+++ b/docs/mllib-statistics.md
@@ -197,7 +197,7 @@ print Statistics.corr(data, method="pearson")
## Stratified sampling
-Unlike the other statistics functions, which reside in MLLib, stratified sampling methods,
+Unlike the other statistics functions, which reside in MLlib, stratified sampling methods,
`sampleByKey` and `sampleByKeyExact`, can be performed on RDD's of key-value pairs. For stratified
sampling, the keys can be thought of as a label and the value as a specific attribute. For example
the key can be man or woman, or document ids, and the respective values can be the list of ages
@@ -380,6 +380,46 @@ for (ChiSqTestResult result : featureTestResults) {
{% endhighlight %}
+
+[`Statistics`](api/python/index.html#pyspark.mllib.stat.Statistics$) provides methods to
+run Pearson's chi-squared tests. The following example demonstrates how to run and interpret
+hypothesis tests.
+
+{% highlight python %}
+from pyspark import SparkContext
+from pyspark.mllib.linalg import Vectors, Matrices
+from pyspark.mllib.regresssion import LabeledPoint
+from pyspark.mllib.stat import Statistics
+
+sc = SparkContext()
+
+vec = Vectors.dense(...) # a vector composed of the frequencies of events
+
+# compute the goodness of fit. If a second vector to test against is not supplied as a parameter,
+# the test runs against a uniform distribution.
+goodnessOfFitTestResult = Statistics.chiSqTest(vec)
+print goodnessOfFitTestResult # summary of the test including the p-value, degrees of freedom,
+ # test statistic, the method used, and the null hypothesis.
+
+mat = Matrices.dense(...) # a contingency matrix
+
+# conduct Pearson's independence test on the input contingency matrix
+independenceTestResult = Statistics.chiSqTest(mat)
+print independenceTestResult # summary of the test including the p-value, degrees of freedom...
+
+obs = sc.parallelize(...) # LabeledPoint(feature, label) .
+
+# The contingency table is constructed from an RDD of LabeledPoint and used to conduct
+# the independence test. Returns an array containing the ChiSquaredTestResult for every feature
+# against the label.
+featureTestResults = Statistics.chiSqTest(obs)
+
+for i, result in enumerate(featureTestResults):
+ print "Column $d:" % (i + 1)
+ print result
+{% endhighlight %}
+
+
## Random data generation
diff --git a/docs/monitoring.md b/docs/monitoring.md
index d07ec4a57a2cc..f32cdef240d31 100644
--- a/docs/monitoring.md
+++ b/docs/monitoring.md
@@ -77,6 +77,13 @@ follows:
one implementation, provided by Spark, which looks for application logs stored in the
file system.
+
+
spark.history.fs.logDirectory
+
file:/tmp/spark-events
+
+ Directory that contains application event logs to be loaded by the history server
+
+
spark.history.fs.updateInterval
10
diff --git a/docs/programming-guide.md b/docs/programming-guide.md
index 8e8cc1dd983f8..c60de6e970531 100644
--- a/docs/programming-guide.md
+++ b/docs/programming-guide.md
@@ -117,6 +117,8 @@ The first thing a Spark program must do is to create a [SparkContext](api/scala/
how to access a cluster. To create a `SparkContext` you first need to build a [SparkConf](api/scala/index.html#org.apache.spark.SparkConf) object
that contains information about your application.
+Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before creating a new one.
+
{% highlight scala %}
val conf = new SparkConf().setAppName(appName).setMaster(master)
new SparkContext(conf)
@@ -211,17 +213,17 @@ For a complete list of options, run `pyspark --help`. Behind the scenes,
It is also possible to launch the PySpark shell in [IPython](http://ipython.org), the
enhanced Python interpreter. PySpark works with IPython 1.0.0 and later. To
-use IPython, set the `PYSPARK_PYTHON` variable to `ipython` when running `bin/pyspark`:
+use IPython, set the `PYSPARK_DRIVER_PYTHON` variable to `ipython` when running `bin/pyspark`:
{% highlight bash %}
-$ PYSPARK_PYTHON=ipython ./bin/pyspark
+$ PYSPARK_DRIVER_PYTHON=ipython ./bin/pyspark
{% endhighlight %}
-You can customize the `ipython` command by setting `PYSPARK_PYTHON_OPTS`. For example, to launch
+You can customize the `ipython` command by setting `PYSPARK_DRIVER_PYTHON_OPTS`. For example, to launch
the [IPython Notebook](http://ipython.org/notebook.html) with PyLab plot support:
{% highlight bash %}
-$ PYSPARK_PYTHON=ipython PYSPARK_PYTHON_OPTS="notebook --pylab inline" ./bin/pyspark
+$ PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS="notebook --pylab inline" ./bin/pyspark
{% endhighlight %}
@@ -1131,7 +1133,7 @@ method. The code below shows this:
{% highlight scala %}
scala> val broadcastVar = sc.broadcast(Array(1, 2, 3))
-broadcastVar: spark.Broadcast[Array[Int]] = spark.Broadcast(b5c40191-a864-4c7d-b9bf-d87e1a4e787c)
+broadcastVar: org.apache.spark.broadcast.Broadcast[Array[Int]] = Broadcast(0)
scala> broadcastVar.value
res0: Array[Int] = Array(1, 2, 3)
@@ -1304,6 +1306,12 @@ vecAccum = sc.accumulator(Vector(...), VectorAccumulatorParam())
+For accumulator updates performed inside actions only, Spark guarantees that each task's update to the accumulator
+will only be applied once, i.e. restarted tasks will not update the value. In transformations, users should be aware
+of that each task's update may be applied more than once if tasks or job stages are re-executed.
+
+
+
# Deploying to a Cluster
The [application submission guide](submitting-applications.html) describes how to submit applications to a cluster.
diff --git a/docs/quick-start.md b/docs/quick-start.md
index 23313d8aa6152..bf643bb70e153 100644
--- a/docs/quick-start.md
+++ b/docs/quick-start.md
@@ -8,7 +8,7 @@ title: Quick Start
This tutorial provides a quick introduction to using Spark. We will first introduce the API through Spark's
interactive shell (in Python or Scala),
-then show how to write standalone applications in Java, Scala, and Python.
+then show how to write applications in Java, Scala, and Python.
See the [programming guide](programming-guide.html) for a more complete reference.
To follow along with this guide, first download a packaged release of Spark from the
@@ -215,8 +215,8 @@ a cluster, as described in the [programming guide](programming-guide.html#initia
-# Standalone Applications
-Now say we wanted to write a standalone application using the Spark API. We will walk through a
+# Self-Contained Applications
+Now say we wanted to write a self-contained application using the Spark API. We will walk through a
simple application in both Scala (with SBT), Java (with Maven), and Python.
@@ -244,6 +244,9 @@ object SimpleApp {
}
{% endhighlight %}
+Note that applications should define a `main()` method instead of extending `scala.App`.
+Subclasses of `scala.App` may not work correctly.
+
This program just counts the number of lines containing 'a' and the number containing 'b' in the
Spark README. Note that you'll need to replace YOUR_SPARK_HOME with the location where Spark is
installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkContext,
@@ -387,7 +390,7 @@ Lines with a: 46, Lines with b: 23
-Now we will show how to write a standalone application using the Python API (PySpark).
+Now we will show how to write an application using the Python API (PySpark).
As an example, we'll create a simple Spark application, `SimpleApp.py`:
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index 695813a2ba881..dfe2db4b3fce8 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -4,7 +4,7 @@ title: Running Spark on YARN
---
Support for running on [YARN (Hadoop
-NextGen)](http://hadoop.apache.org/docs/r2.0.2-alpha/hadoop-yarn/hadoop-yarn-site/YARN.html)
+NextGen)](http://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/YARN.html)
was added to Spark in version 0.6.0, and improved in subsequent releases.
# Preparations
@@ -39,7 +39,7 @@ Most of the configs are the same for Spark on YARN as for other deployment modes
spark.yarn.preserve.staging.files
false
- Set to true to preserve the staged files (Spark jar, app jar, distributed cache files) at the end of the job rather then delete them.
+ Set to true to preserve the staged files (Spark jar, app jar, distributed cache files) at the end of the job rather than delete them.
@@ -159,7 +159,7 @@ For example:
lib/spark-examples*.jar \
10
-The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Viewing Logs" section below for how to see driver and executor logs.
+The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs.
To launch a Spark application in yarn-client mode, do the same, but replace "yarn-cluster" with "yarn-client". To run spark-shell:
@@ -181,7 +181,7 @@ In YARN terminology, executors and application masters run inside "containers".
yarn logs -applicationId
-will print out the contents of all log files from all containers from the given application.
+will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`).
When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID.
diff --git a/docs/security.md b/docs/security.md
index ec0523184d665..1e206a139fb72 100644
--- a/docs/security.md
+++ b/docs/security.md
@@ -7,7 +7,6 @@ Spark currently supports authentication via a shared secret. Authentication can
* For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret.
* For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications.
-* **IMPORTANT NOTE:** *The experimental Netty shuffle path (`spark.shuffle.use.netty`) is not secured, so do not use Netty for shuffles if running with authentication.*
## Web UI
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 368c3d0008b07..5500da83b2b66 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -14,7 +14,7 @@ title: Spark SQL Programming Guide
Spark SQL allows relational queries expressed in SQL, HiveQL, or Scala to be executed using
Spark. At the core of this component is a new type of RDD,
[SchemaRDD](api/scala/index.html#org.apache.spark.sql.SchemaRDD). SchemaRDDs are composed of
-[Row](api/scala/index.html#org.apache.spark.sql.catalyst.expressions.Row) objects, along with
+[Row](api/scala/index.html#org.apache.spark.sql.package@Row:org.apache.spark.sql.catalyst.expressions.Row.type) objects, along with
a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table
in a traditional relational database. A SchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io)
file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/).
@@ -582,19 +582,27 @@ Configuration of Parquet can be done using the `setConf` method on SQLContext or
spark.sql.parquet.cacheMetadata
-
false
+
true
Turns on caching of Parquet schema metadata. Can speed up querying of static data.
spark.sql.parquet.compression.codec
-
snappy
+
gzip
Sets the compression codec use when writing Parquet files. Acceptable values include:
uncompressed, snappy, gzip, lzo.
+
+
spark.sql.hive.convertMetastoreParquet
+
true
+
+ When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of the built in
+ support.
+
+
## JSON Datasets
@@ -720,7 +728,7 @@ anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD)
Spark SQL also supports reading and writing data stored in [Apache Hive](http://hive.apache.org/).
However, since Hive has a large number of dependencies, it is not included in the default Spark assembly.
-In order to use Hive you must first run "`sbt/sbt -Phive assembly/assembly`" (or use `-Phive` for maven).
+Hive support is enabled by adding the `-Phive` and `-Phive-thriftserver` flags to Spark's build.
This command builds a new assembly jar that includes Hive. Note that this Hive assembly jar must also be present
on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries
(SerDes) in order to access data stored in Hive.
@@ -815,7 +823,7 @@ Configuration of in-memory caching can be done using the `setConf` method on SQL
Property Name
Default
Meaning
spark.sql.inMemoryColumnarStorage.compressed
-
false
+
true
When set to true Spark SQL will automatically select a compression codec for each column based
on statistics of the data.
@@ -823,7 +831,7 @@ Configuration of in-memory caching can be done using the `setConf` method on SQL
spark.sql.inMemoryColumnarStorage.batchSize
-
1000
+
10000
Controls the size of batches for columnar caching. Larger batch sizes can improve memory utilization
and compression, but risk OOMs when caching data.
@@ -841,7 +849,7 @@ that these options will be deprecated in future release as more optimizations ar
Property Name
Default
Meaning
spark.sql.autoBroadcastJoinThreshold
-
10000
+
10485760 (10 MB)
Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when
performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently
@@ -1051,7 +1059,6 @@ in Hive deployments.
**Major Hive Features**
-* Spark SQL does not currently support inserting to tables using dynamic partitioning.
* Tables with buckets: bucket is the hash partitioning within a Hive table partition. Spark SQL
doesn't support buckets yet.
@@ -1215,7 +1222,7 @@ import org.apache.spark.sql._
DecimalType
-
scala.math.sql.BigDecimal
+
scala.math.BigDecimal
DecimalType
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 5c21e912ea160..44a1f3ad7560b 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -68,7 +68,9 @@ import org.apache.spark._
import org.apache.spark.streaming._
import org.apache.spark.streaming.StreamingContext._
-// Create a local StreamingContext with two working thread and batch interval of 1 second
+// Create a local StreamingContext with two working thread and batch interval of 1 second.
+// The master requires 2 cores to prevent from a starvation scenario.
+
val conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount")
val ssc = new StreamingContext(conf, Seconds(1))
{% endhighlight %}
@@ -212,6 +214,67 @@ The complete code can be found in the Spark Streaming example
[JavaNetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java).
+
+
+First, we import StreamingContext, which is the main entry point for all streaming functionality. We create a local StreamingContext with two execution threads, and batch interval of 1 second.
+
+{% highlight python %}
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+
+# Create a local StreamingContext with two working thread and batch interval of 1 second
+sc = SparkContext("local[2]", "NetworkWordCount")
+ssc = StreamingContext(sc, 1)
+{% endhighlight %}
+
+Using this context, we can create a DStream that represents streaming data from a TCP
+source hostname, e.g. `localhost`, and port, e.g. `9999`
+
+{% highlight python %}
+# Create a DStream that will connect to hostname:port, like localhost:9999
+lines = ssc.socketTextStream("localhost", 9999)
+{% endhighlight %}
+
+This `lines` DStream represents the stream of data that will be received from the data
+server. Each record in this DStream is a line of text. Next, we want to split the lines by
+space into words.
+
+{% highlight python %}
+# Split each line into words
+words = lines.flatMap(lambda line: line.split(" "))
+{% endhighlight %}
+
+`flatMap` is a one-to-many DStream operation that creates a new DStream by
+generating multiple new records from each record in the source DStream. In this case,
+each line will be split into multiple words and the stream of words is represented as the
+`words` DStream. Next, we want to count these words.
+
+{% highlight python %}
+# Count each word in each batch
+pairs = words.map(lambda word: (word, 1))
+wordCounts = pairs.reduceByKey(lambda x, y: x + y)
+
+# Print the first ten elements of each RDD generated in this DStream to the console
+wordCounts.pprint()
+{% endhighlight %}
+
+The `words` DStream is further mapped (one-to-one transformation) to a DStream of `(word,
+1)` pairs, which is then reduced to get the frequency of words in each batch of data.
+Finally, `wordCounts.pprint()` will print a few of the counts generated every second.
+
+Note that when these lines are executed, Spark Streaming only sets up the computation it
+will perform when it is started, and no real processing has started yet. To start the processing
+after all the transformations have been setup, we finally call
+
+{% highlight python %}
+ssc.start() # Start the computation
+ssc.awaitTermination() # Wait for the computation to terminate
+{% endhighlight %}
+
+The complete code can be found in the Spark Streaming example
+[NetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/network_wordcount.py).
+
+
+
+A [StreamingContext](api/python/pyspark.streaming.html#pyspark.streaming.StreamingContext) object can be created from a [SparkContext](api/python/pyspark.html#pyspark.SparkContext) object.
+
+{% highlight python %}
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+
+sc = SparkContext(master, appName)
+ssc = StreamingContext(sc, 1)
+{% endhighlight %}
+
+The `appName` parameter is a name for your application to show on the cluster UI.
+`master` is a [Spark, Mesos or YARN cluster URL](submitting-applications.html#master-urls),
+or a special __"local[\*]"__ string to run in local mode. In practice, when running on a cluster,
+you will not want to hardcode `master` in the program,
+but rather [launch the application with `spark-submit`](submitting-applications.html) and
+receive it there. However, for local testing and unit tests, you can pass "local[\*]" to run Spark Streaming
+in-process (detects the number of cores in the local system).
+
+The batch interval must be set based on the latency requirements of your application
+and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size)
+section for more details.
+
After a context is defined, you have to do the follow steps.
+
1. Define the input sources.
1. Setup the streaming computations.
1. Start the receiving and procesing of data using `streamingContext.start()`.
@@ -461,11 +588,13 @@ Every input DStream (except file stream) is associated with a single [Receiver](
A receiver is run within a Spark worker/executor as a long-running task, hence it occupies one of the cores allocated to the Spark Streaming application. Hence, it is important to remember that Spark Streaming application needs to be allocated enough cores to process the received data, as well as, to run the receiver(s). Therefore, few important points to remember are:
-##### Points to remember:
+##### Points to remember
{:.no_toc}
-- If the number of cores allocated to the application is less than or equal to the number of input DStreams / receivers, then the system will receive data, but not be able to process them.
-- When running locally, if you master URL is set to "local", then there is only one core to run tasks. That is insufficient for programs with even one input DStream (file streams are okay) as the receiver will occupy that core and there will be no core left to process the data.
-
+- If the number of threads allocated to the application is less than or equal to the number of input DStreams / receivers, then the system will receive data, but not be able to process them.
+- When running locally, if you master URL is set to "local", then there is only one core to run tasks. That is insufficient for programs using a DStream as the receiver (file streams are okay). So, a "local" master URL in a streaming app is generally going to cause starvation for the processor.
+Thus in any streaming app, you generally will want to allocate more than one thread (i.e. set your master to "local[2]") when testing locally.
+See [Spark Properties] (configuration.html#spark-properties.html).
+
### Basic Sources
{:.no_toc}
@@ -483,6 +612,9 @@ methods for creating DStreams from files and Akka actors as input sources.
Spark Streaming will monitor the directory `dataDirectory` and process any files created in that directory (files written in nested directories not supported). Note that
@@ -494,7 +626,7 @@ methods for creating DStreams from files and Akka actors as input sources.
For simple text files, there is an easier method `streamingContext.textFileStream(dataDirectory)`. And file streams do not require running a receiver, hence does not require allocating cores.
-- **Streams based on Custom Actors:** DStreams can be created with data streams received through Akka actors by using `streamingContext.actorStream(actorProps, actor-name)`. See the [Custom Receiver Guide](#implementing-and-using-a-custom-actor-based-receiver) for more details.
+- **Streams based on Custom Actors:** DStreams can be created with data streams received through Akka actors by using `streamingContext.actorStream(actorProps, actor-name)`. See the [Custom Receiver Guide](streaming-custom-receivers.html#implementing-and-using-a-custom-actor-based-receiver) for more details.
- **Queue of RDDs as a Stream:** For testing a Spark Streaming application with test data, one can also create a DStream based on a queue of RDDs, using `streamingContext.queueStream(queueOfRDDs)`. Each RDD pushed into the queue will be treated as a batch of data in the DStream, and processed like a stream.
@@ -684,13 +816,30 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi
JavaPairDStream runningCounts = pairs.updateStateByKey(updateFunction);
{% endhighlight %}
+
+
+
+{% highlight python %}
+def updateFunction(newValues, runningCount):
+ if runningCount is None:
+ runningCount = 0
+ return sum(newValues, runningCount) # add the new values with the previous running count to get the new count
+{% endhighlight %}
+
+This is applied on a DStream containing words (say, the `pairs` DStream containing `(word,
+1)` pairs in the [earlier example](#a-quick-example)).
+
+{% highlight python %}
+runningCounts = pairs.updateStateByKey(updateFunction)
+{% endhighlight %}
+
The update function will be called for each word, with `newValues` having a sequence of 1's (from
the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete
Scala code, take a look at the example
-[StatefulNetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala).
+[stateful_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/stateful_network_wordcount.py).
#### Transform Operation
{:.no_toc}
@@ -732,6 +881,15 @@ JavaPairDStream cleanedDStream = wordCounts.transform(
});
{% endhighlight %}
+
+
+
+{% highlight python %}
+spamInfoRDD = sc.pickleFile(...) # RDD containing spam information
+
+# join data stream with spam information to do data cleaning
+cleanedDStream = wordCounts.transform(lambda rdd: rdd.join(spamInfoRDD).filter(...))
+{% endhighlight %}
@@ -793,6 +951,14 @@ Function2 reduceFunc = new Function2 windowedWordCounts = pairs.reduceByKeyAndWindow(reduceFunc, new Duration(30000), new Duration(10000));
{% endhighlight %}
+
+
+
+{% highlight python %}
+# Reduce last 30 seconds of data, every 10 seconds
+windowedWordCounts = pairs.reduceByKeyAndWindow(lambda x, y: x + y, lambda x, y: x - y, 30, 10)
+{% endhighlight %}
+
@@ -860,6 +1026,7 @@ see [DStream](api/scala/index.html#org.apache.spark.streaming.dstream.DStream)
and [PairDStreamFunctions](api/scala/index.html#org.apache.spark.streaming.dstream.PairDStreamFunctions).
For the Java API, see [JavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaDStream.html)
and [JavaPairDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaPairDStream.html).
+For the Python API, see [DStream](api/python/pyspark.streaming.html#pyspark.streaming.DStream)
***
@@ -872,9 +1039,12 @@ Currently, the following output operations are defined:
Output Operation
Meaning
-
print()
+
print()
Prints first ten elements of every batch of data in a DStream on the driver.
- This is useful for development and debugging.
+ This is useful for development and debugging.
+
+ PS: called pprint() in Python)
+
saveAsObjectFiles(prefix, [suffix])
@@ -915,17 +1085,41 @@ For this purpose, a developer may inadvertantly try creating a connection object
the Spark driver, but try to use it in a Spark worker to save records in the RDDs.
For example (in Scala),
+
+
+
+{% highlight scala %}
dstream.foreachRDD(rdd => {
val connection = createNewConnection() // executed at the driver
rdd.foreach(record => {
connection.send(record) // executed at the worker
})
})
+{% endhighlight %}
- This is incorrect as this requires the connection object to be serialized and sent from the driver to the worker. Such connection objects are rarely transferrable across machines. This error may manifest as serialization errors (connection object not serializable), initialization errors (connection object needs to be initialized at the workers), etc. The correct solution is to create the connection object at the worker.
+
+
+ This is incorrect as this requires the connection object to be serialized and sent from the driver to the worker. Such connection objects are rarely transferrable across machines. This error may manifest as serialization errors (connection object not serializable), initialization errors (connection object needs to be initialized at the workers), etc. The correct solution is to create the connection object at the worker.
- However, this can lead to another common mistake - creating a new connection for every record. For example,
+
+
+
+{% highlight scala %}
dstream.foreachRDD(rdd => {
rdd.foreach(record => {
val connection = createNewConnection()
@@ -933,9 +1127,28 @@ For example (in Scala),
connection.close()
})
})
+{% endhighlight %}
- Typically, creating a connection object has time and resource overheads. Therefore, creating and destroying a connection object for each record can incur unnecessarily high overheads and can significantly reduce the overall throughput of the system. A better solution is to use `rdd.foreachPartition` - create a single connection object and send all the records in a RDD partition using that connection.
+
+
+ Typically, creating a connection object has time and resource overheads. Therefore, creating and destroying a connection object for each record can incur unnecessarily high overheads and can significantly reduce the overall throughput of the system. A better solution is to use `rdd.foreachPartition` - create a single connection object and send all the records in a RDD partition using that connection.
+
+
+
+{% highlight scala %}
dstream.foreachRDD(rdd => {
rdd.foreachPartition(partitionOfRecords => {
val connection = createNewConnection()
@@ -943,13 +1156,31 @@ For example (in Scala),
connection.close()
})
})
+{% endhighlight %}
+
+
+
+{% highlight python %}
+def sendPartition(iter):
+ connection = createNewConnection()
+ for record in iter:
+ connection.send(record)
+ connection.close()
+
+dstream.foreachRDD(lambda rdd: rdd.foreachPartition(sendPartition))
+{% endhighlight %}
+
+
- This amortizes the connection creation overheads over many records.
+ This amortizes the connection creation overheads over many records.
- Finally, this can be further optimized by reusing connection objects across multiple RDDs/batches.
One can maintain a static pool of connection objects than can be reused as
RDDs of multiple batches are pushed to the external system, thus further reducing the overheads.
-
+
+
+
+{% highlight scala %}
dstream.foreachRDD(rdd => {
rdd.foreachPartition(partitionOfRecords => {
// ConnectionPool is a static, lazily initialized pool of connections
@@ -958,8 +1189,25 @@ For example (in Scala),
ConnectionPool.returnConnection(connection) // return to the pool for future reuse
})
})
+{% endhighlight %}
+
- Note that the connections in the pool should be lazily created on demand and timed out if not used for a while. This achieves the most efficient sending of data to external systems.
+
+{% highlight python %}
+def sendPartition(iter):
+ # ConnectionPool is a static, lazily initialized pool of connections
+ connection = ConnectionPool.getConnection()
+ for record in iter:
+ connection.send(record)
+ # return to the pool for future reuse
+ ConnectionPool.returnConnection(connection)
+
+dstream.foreachRDD(lambda rdd: rdd.foreachPartition(sendPartition))
+{% endhighlight %}
+
+
+
+Note that the connections in the pool should be lazily created on demand and timed out if not used for a while. This achieves the most efficient sending of data to external systems.
##### Other points to remember:
@@ -1376,6 +1624,44 @@ You can also explicitly create a `JavaStreamingContext` from the checkpoint data
the computation by using `new JavaStreamingContext(checkpointDirectory)`.
+
+
+This behavior is made simple by using `StreamingContext.getOrCreate`. This is used as follows.
+
+{% highlight python %}
+# Function to create and setup a new StreamingContext
+def functionToCreateContext():
+ sc = SparkContext(...) # new context
+ ssc = new StreamingContext(...)
+ lines = ssc.socketTextStream(...) # create DStreams
+ ...
+ ssc.checkpoint(checkpointDirectory) # set checkpoint directory
+ return ssc
+
+# Get StreamingContext from checkpoint data or create a new one
+context = StreamingContext.getOrCreate(checkpointDirectory, functionToCreateContext)
+
+# Do additional setup on context that needs to be done,
+# irrespective of whether it is being started or restarted
+context. ...
+
+# Start the context
+context.start()
+context.awaitTermination()
+{% endhighlight %}
+
+If the `checkpointDirectory` exists, then the context will be recreated from the checkpoint data.
+If the directory does not exist (i.e., running for the first time),
+then the function `functionToCreateContext` will be called to create a new
+context and set up the DStreams. See the Python example
+[recoverable_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python/streaming/recoverable_network_wordcount.py).
+This example appends the word counts of network data into a file.
+
+You can also explicitly create a `StreamingContext` from the checkpoint data and start the
+ computation by using `StreamingContext.getOrCreate(checkpointDirectory, None)`.
+
+
+
**Note**: If Spark Streaming and/or the Spark Streaming program is recompiled,
@@ -1572,7 +1858,11 @@ package and renamed for better clarity.
[TwitterUtils](api/java/index.html?org/apache/spark/streaming/twitter/TwitterUtils.html),
[ZeroMQUtils](api/java/index.html?org/apache/spark/streaming/zeromq/ZeroMQUtils.html), and
[MQTTUtils](api/java/index.html?org/apache/spark/streaming/mqtt/MQTTUtils.html)
+ - Python docs
+ * [StreamingContext](api/python/pyspark.streaming.html#pyspark.streaming.StreamingContext)
+ * [DStream](api/python/pyspark.streaming.html#pyspark.streaming.DStream)
* More examples in [Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming)
and [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming)
+ and [Python] ({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python/streaming)
* [Paper](http://www.eecs.berkeley.edu/Pubs/TechRpts/2012/EECS-2012-259.pdf) and [video](http://youtu.be/g171ndOHgJ0) describing Spark Streaming.
diff --git a/docs/tuning.md b/docs/tuning.md
index 8fb2a0433b1a8..9b5c9adac6a4f 100644
--- a/docs/tuning.md
+++ b/docs/tuning.md
@@ -47,24 +47,11 @@ registration requirement, but we recommend trying it in any network-intensive ap
Spark automatically includes Kryo serializers for the many commonly-used core Scala classes covered
in the AllScalaRegistrar from the [Twitter chill](https://github.com/twitter/chill) library.
-To register your own custom classes with Kryo, create a public class that extends
-[`org.apache.spark.serializer.KryoRegistrator`](api/scala/index.html#org.apache.spark.serializer.KryoRegistrator) and set the
-`spark.kryo.registrator` config property to point to it, as follows:
+To register your own custom classes with Kryo, use the `registerKryoClasses` method.
{% highlight scala %}
-import com.esotericsoftware.kryo.Kryo
-import org.apache.spark.serializer.KryoRegistrator
-
-class MyRegistrator extends KryoRegistrator {
- override def registerClasses(kryo: Kryo) {
- kryo.register(classOf[MyClass1])
- kryo.register(classOf[MyClass2])
- }
-}
-
val conf = new SparkConf().setMaster(...).setAppName(...)
-conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
-conf.set("spark.kryo.registrator", "mypackage.MyRegistrator")
+conf.registerKryoClasses(Seq(classOf[MyClass1], classOf[MyClass2]))
val sc = new SparkContext(conf)
{% endhighlight %}
diff --git a/ec2/spark-ec2 b/ec2/spark-ec2
index 31f9771223e51..4aa908242eeaa 100755
--- a/ec2/spark-ec2
+++ b/ec2/spark-ec2
@@ -18,5 +18,9 @@
# limitations under the License.
#
-cd "`dirname $0`"
-PYTHONPATH="./third_party/boto-2.4.1.zip/boto-2.4.1:$PYTHONPATH" python ./spark_ec2.py "$@"
+# Preserve the user's CWD so that relative paths are passed correctly to
+#+ the underlying Python script.
+SPARK_EC2_DIR="$(dirname $0)"
+
+PYTHONPATH="${SPARK_EC2_DIR}/third_party/boto-2.4.1.zip/boto-2.4.1:$PYTHONPATH" \
+ python "${SPARK_EC2_DIR}/spark_ec2.py" "$@"
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 941dfb988b9fb..742c7765e728e 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -32,6 +32,7 @@
import tempfile
import time
import urllib2
+import warnings
from optparse import OptionParser
from sys import stderr
import boto
@@ -39,9 +40,11 @@
from boto import ec2
DEFAULT_SPARK_VERSION = "1.1.0"
+SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__))
+MESOS_SPARK_EC2_BRANCH = "v4"
# A URL prefix from which to fetch AMI information
-AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/v2/ami-list"
+AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/{b}/ami-list".format(b=MESOS_SPARK_EC2_BRANCH)
class UsageError(Exception):
@@ -61,8 +64,8 @@ def parse_args():
"-s", "--slaves", type="int", default=1,
help="Number of slaves to launch (default: %default)")
parser.add_option(
- "-w", "--wait", type="int", default=120,
- help="Seconds to wait for nodes to start (default: %default)")
+ "-w", "--wait", type="int",
+ help="DEPRECATED (no longer necessary) - Seconds to wait for nodes to start")
parser.add_option(
"-k", "--key-pair",
help="Key pair to use on instances")
@@ -83,7 +86,7 @@ def parse_args():
"-z", "--zone", default="",
help="Availability zone to launch instances in, or 'all' to spread " +
"slaves across multiple (an additional $0.01/Gb for bandwidth" +
- "between zones applies)")
+ "between zones applies) (default: a single zone chosen at random)")
parser.add_option("-a", "--ami", help="Amazon Machine Image ID to use")
parser.add_option(
"-v", "--spark-version", default=DEFAULT_SPARK_VERSION,
@@ -135,7 +138,7 @@ def parse_args():
help="The SSH user you want to connect as (default: %default)")
parser.add_option(
"--delete-groups", action="store_true", default=False,
- help="When destroying a cluster, delete the security groups that were created.")
+ help="When destroying a cluster, delete the security groups that were created")
parser.add_option(
"--use-existing-master", action="store_true", default=False,
help="Launch fresh slaves, but use an existing stopped master if possible")
@@ -149,9 +152,6 @@ def parse_args():
parser.add_option(
"--user-data", type="string", default="",
help="Path to a user-data file (most AMI's interpret this as an initialization script)")
- parser.add_option(
- "--security-group-prefix", type="string", default=None,
- help="Use this prefix for the security group rather than the cluster name.")
parser.add_option(
"--authorized-address", type="string", default="0.0.0.0/0",
help="Address to authorize on created security groups (default: %default)")
@@ -195,18 +195,6 @@ def get_or_make_group(conn, name):
return conn.create_security_group(name, "Spark EC2 group")
-# Wait for a set of launched instances to exit the "pending" state
-# (i.e. either to start running or to fail and be terminated)
-def wait_for_instances(conn, instances):
- while True:
- for i in instances:
- i.update()
- if len([i for i in instances if i.state == 'pending']) > 0:
- time.sleep(5)
- else:
- return
-
-
# Check whether a given EC2 instance object is in a state we consider active,
# i.e. not terminating or terminated. We count both stopping and stopped as
# active since we can restart stopped clusters.
@@ -314,12 +302,8 @@ def launch_cluster(conn, opts, cluster_name):
user_data_content = user_data_file.read()
print "Setting up security groups..."
- if opts.security_group_prefix is None:
- master_group = get_or_make_group(conn, cluster_name + "-master")
- slave_group = get_or_make_group(conn, cluster_name + "-slaves")
- else:
- master_group = get_or_make_group(conn, opts.security_group_prefix + "-master")
- slave_group = get_or_make_group(conn, opts.security_group_prefix + "-slaves")
+ master_group = get_or_make_group(conn, cluster_name + "-master")
+ slave_group = get_or_make_group(conn, cluster_name + "-slaves")
authorized_address = opts.authorized_address
if master_group.rules == []: # Group was just now created
master_group.authorize(src_group=master_group)
@@ -344,11 +328,12 @@ def launch_cluster(conn, opts, cluster_name):
slave_group.authorize('tcp', 60060, 60060, authorized_address)
slave_group.authorize('tcp', 60075, 60075, authorized_address)
- # Check if instances are already running with the cluster name
+ # Check if instances are already running in our groups
existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name,
die_on_error=False)
if existing_slaves or (existing_masters and not opts.use_existing_master):
- print >> stderr, ("ERROR: There are already instances for name: %s " % cluster_name)
+ print >> stderr, ("ERROR: There are already instances running in " +
+ "group %s or %s" % (master_group.name, slave_group.name))
sys.exit(1)
# Figure out Spark AMI
@@ -422,13 +407,9 @@ def launch_cluster(conn, opts, cluster_name):
for r in reqs:
id_to_req[r.id] = r
active_instance_ids = []
- outstanding_request_ids = []
for i in my_req_ids:
- if i in id_to_req:
- if id_to_req[i].state == "active":
- active_instance_ids.append(id_to_req[i].instance_id)
- else:
- outstanding_request_ids.append(i)
+ if i in id_to_req and id_to_req[i].state == "active":
+ active_instance_ids.append(id_to_req[i].instance_id)
if len(active_instance_ids) == opts.slaves:
print "All %d slaves granted" % opts.slaves
reservations = conn.get_all_instances(active_instance_ids)
@@ -437,8 +418,8 @@ def launch_cluster(conn, opts, cluster_name):
slave_nodes += r.instances
break
else:
- print "%d of %d slaves granted, waiting longer for request ids including %s" % (
- len(active_instance_ids), opts.slaves, outstanding_request_ids[0:10])
+ print "%d of %d slaves granted, waiting longer" % (
+ len(active_instance_ids), opts.slaves)
except:
print "Canceling spot instance requests"
conn.cancel_spot_instance_requests(my_req_ids)
@@ -497,59 +478,34 @@ def launch_cluster(conn, opts, cluster_name):
# Give the instances descriptive names
for master in master_nodes:
- name = '{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)
- tag_instance(master, name)
-
+ master.add_tag(
+ key='Name',
+ value='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id))
for slave in slave_nodes:
- name = '{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)
- tag_instance(slave, name)
+ slave.add_tag(
+ key='Name',
+ value='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id))
# Return all the instances
return (master_nodes, slave_nodes)
-def tag_instance(instance, name):
- for i in range(0, 5):
- try:
- instance.add_tag(key='Name', value=name)
- break
- except:
- print "Failed attempt %i of 5 to tag %s" % ((i + 1), name)
- if i == 5:
- raise "Error - failed max attempts to add name tag"
- time.sleep(5)
-
# Get the EC2 instances in an existing cluster if available.
# Returns a tuple of lists of EC2 instance objects for the masters and slaves
def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
print "Searching for existing cluster " + cluster_name + "..."
- # Search all the spot instance requests, and copy any tags from the spot
- # instance request to the cluster.
- spot_instance_requests = conn.get_all_spot_instance_requests()
- for req in spot_instance_requests:
- if req.state != u'active':
- continue
- name = req.tags.get(u'Name', "")
- if name.startswith(cluster_name):
- reservations = conn.get_all_instances(instance_ids=[req.instance_id])
- for res in reservations:
- active = [i for i in res.instances if is_active(i)]
- for instance in active:
- if instance.tags.get(u'Name') is None:
- tag_instance(instance, name)
- # Now proceed to detect master and slaves instances.
reservations = conn.get_all_instances()
master_nodes = []
slave_nodes = []
for res in reservations:
active = [i for i in res.instances if is_active(i)]
for inst in active:
- name = inst.tags.get(u'Name', "")
- if name.startswith(cluster_name + "-master"):
+ group_names = [g.name for g in inst.groups]
+ if group_names == [cluster_name + "-master"]:
master_nodes.append(inst)
- elif name.startswith(cluster_name + "-slave"):
+ elif group_names == [cluster_name + "-slaves"]:
slave_nodes.append(inst)
if any((master_nodes, slave_nodes)):
print "Found %d master(s), %d slaves" % (len(master_nodes), len(slave_nodes))
@@ -557,12 +513,12 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
return (master_nodes, slave_nodes)
else:
if master_nodes == [] and slave_nodes != []:
- print >> sys.stderr, "ERROR: Could not find master in with name " + \
- cluster_name + "-master"
+ print >> sys.stderr, "ERROR: Could not find master in group " + cluster_name + "-master"
else:
print >> sys.stderr, "ERROR: Could not find any existing cluster"
sys.exit(1)
+
# Deploy configuration files and run setup scripts on a newly launched
# or started EC2 cluster.
@@ -594,10 +550,23 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
# NOTE: We should clone the repository before running deploy_files to
# prevent ec2-variables.sh from being overwritten
- ssh(master, opts, "rm -rf spark-ec2 && git clone https://github.com/mesos/spark-ec2.git -b v3")
+ ssh(
+ host=master,
+ opts=opts,
+ command="rm -rf spark-ec2"
+ + " && "
+ + "git clone https://github.com/mesos/spark-ec2.git -b {b}".format(b=MESOS_SPARK_EC2_BRANCH)
+ )
print "Deploying files to master..."
- deploy_files(conn, "deploy.generic", opts, master_nodes, slave_nodes, modules)
+ deploy_files(
+ conn=conn,
+ root_dir=SPARK_EC2_DIR + "/" + "deploy.generic",
+ opts=opts,
+ master_nodes=master_nodes,
+ slave_nodes=slave_nodes,
+ modules=modules
+ )
print "Running setup on master..."
setup_spark_cluster(master, opts)
@@ -619,14 +588,64 @@ def setup_spark_cluster(master, opts):
print "Ganglia started at http://%s:5080/ganglia" % master
-# Wait for a whole cluster (masters, slaves and ZooKeeper) to start up
-def wait_for_cluster(conn, wait_secs, master_nodes, slave_nodes):
- print "Waiting for instances to start up..."
- time.sleep(5)
- wait_for_instances(conn, master_nodes)
- wait_for_instances(conn, slave_nodes)
- print "Waiting %d more seconds..." % wait_secs
- time.sleep(wait_secs)
+def is_ssh_available(host, opts):
+ "Checks if SSH is available on the host."
+ try:
+ with open(os.devnull, 'w') as devnull:
+ ret = subprocess.check_call(
+ ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3',
+ '%s@%s' % (opts.user, host), stringify_command('true')],
+ stdout=devnull,
+ stderr=devnull
+ )
+ return ret == 0
+ except subprocess.CalledProcessError as e:
+ return False
+
+
+def is_cluster_ssh_available(cluster_instances, opts):
+ for i in cluster_instances:
+ if not is_ssh_available(host=i.ip_address, opts=opts):
+ return False
+ else:
+ return True
+
+
+def wait_for_cluster_state(cluster_instances, cluster_state, opts):
+ """
+ cluster_instances: a list of boto.ec2.instance.Instance
+ cluster_state: a string representing the desired state of all the instances in the cluster
+ value can be 'ssh-ready' or a valid value from boto.ec2.instance.InstanceState such as
+ 'running', 'terminated', etc.
+ (would be nice to replace this with a proper enum: http://stackoverflow.com/a/1695250)
+ """
+ sys.stdout.write(
+ "Waiting for all instances in cluster to enter '{s}' state.".format(s=cluster_state)
+ )
+ sys.stdout.flush()
+
+ num_attempts = 0
+
+ while True:
+ time.sleep(3 * num_attempts)
+
+ for i in cluster_instances:
+ s = i.update() # capture output to suppress print to screen in newer versions of boto
+
+ if cluster_state == 'ssh-ready':
+ if all(i.state == 'running' for i in cluster_instances) and \
+ is_cluster_ssh_available(cluster_instances, opts):
+ break
+ else:
+ if all(i.state == cluster_state for i in cluster_instances):
+ break
+
+ num_attempts += 1
+
+ sys.stdout.write(".")
+ sys.stdout.flush()
+
+ sys.stdout.write("\n")
# Get number of local disks available for a given EC2 instance type.
@@ -684,6 +703,8 @@ def get_num_disks(instance_type):
# cluster (e.g. lists of masters and slaves). Files are only deployed to
# the first master instance in the cluster, and we expect the setup
# script to be run on that instance to copy them to other nodes.
+#
+# root_dir should be an absolute path to the directory with the files we want to deploy.
def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules):
active_master = master_nodes[0].public_dns_name
@@ -868,6 +889,16 @@ def real_main():
(opts, action, cluster_name) = parse_args()
# Input parameter validation
+ if opts.wait is not None:
+ # NOTE: DeprecationWarnings are silent in 2.7+ by default.
+ # To show them, run Python with the -Wdefault switch.
+ # See: https://docs.python.org/3.5/whatsnew/2.7.html
+ warnings.warn(
+ "This option is deprecated and has no effect. "
+ "spark-ec2 automatically waits as long as necessary for clusters to startup.",
+ DeprecationWarning
+ )
+
if opts.ebs_vol_num > 8:
print >> stderr, "ebs-vol-num cannot be greater than 8"
sys.exit(1)
@@ -890,7 +921,11 @@ def real_main():
(master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
else:
(master_nodes, slave_nodes) = launch_cluster(conn, opts, cluster_name)
- wait_for_cluster(conn, opts.wait, master_nodes, slave_nodes)
+ wait_for_cluster_state(
+ cluster_instances=(master_nodes + slave_nodes),
+ cluster_state='ssh-ready',
+ opts=opts
+ )
setup_cluster(conn, master_nodes, slave_nodes, opts, True)
elif action == "destroy":
@@ -914,12 +949,12 @@ def real_main():
# Delete security groups as well
if opts.delete_groups:
print "Deleting security groups (this will take some time)..."
- if opts.security_group_prefix is None:
- group_names = [cluster_name + "-master", cluster_name + "-slaves"]
- else:
- group_names = [opts.security_group_prefix + "-master",
- opts.security_group_prefix + "-slaves"]
-
+ group_names = [cluster_name + "-master", cluster_name + "-slaves"]
+ wait_for_cluster_state(
+ cluster_instances=(master_nodes + slave_nodes),
+ cluster_state='terminated',
+ opts=opts
+ )
attempt = 1
while attempt <= 3:
print "Attempt %d" % attempt
@@ -1019,7 +1054,11 @@ def real_main():
for inst in master_nodes:
if inst.state not in ["shutting-down", "terminated"]:
inst.start()
- wait_for_cluster(conn, opts.wait, master_nodes, slave_nodes)
+ wait_for_cluster_state(
+ cluster_instances=(master_nodes + slave_nodes),
+ cluster_state='ssh-ready',
+ opts=opts
+ )
setup_cluster(conn, master_nodes, slave_nodes, opts, False)
else:
diff --git a/examples/pom.xml b/examples/pom.xml
index eb49a0e5af22d..8713230e1e8ed 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -21,7 +21,7 @@
org.apache.sparkspark-parent
- 1.2.0-SNAPSHOT
+ 1.3.0-SNAPSHOT../pom.xml
@@ -34,24 +34,6 @@
Spark Project Exampleshttp://spark.apache.org/
-
-
- kinesis-asl
-
-
- org.apache.spark
- spark-streaming-kinesis-asl_${scala.binary.version}
- ${project.version}
-
-
- org.apache.httpcomponents
- httpclient
- ${commons.httpclient.version}
-
-
-
-
-
@@ -102,12 +84,12 @@
org.apache.spark
- spark-streaming-kafka_${scala.binary.version}
+ spark-streaming-flume_${scala.binary.version}${project.version}org.apache.spark
- spark-streaming-flume_${scala.binary.version}
+ spark-streaming-mqtt_${scala.binary.version}${project.version}
@@ -116,45 +98,151 @@
${project.version}
- org.apache.spark
- spark-streaming-mqtt_${scala.binary.version}
- ${project.version}
+ org.eclipse.jetty
+ jetty-server
-
- org.apache.hbase
- hbase
- ${hbase.version}
-
-
- asm
- asm
-
-
- org.jboss.netty
- netty
-
-
+
+ org.apache.hbase
+ hbase-testing-util
+ ${hbase.version}
+
+
+
+ org.apache.hbase
+ hbase-annotations
+
+
+ org.jruby
+ jruby-complete
+
+
+
+
+ org.apache.hbase
+ hbase-protocol
+ ${hbase.version}
+
+
+ org.apache.hbase
+ hbase-common
+ ${hbase.version}
+
+
+
+ org.apache.hbase
+ hbase-annotations
+
+
+
+
+ org.apache.hbase
+ hbase-client
+ ${hbase.version}
+
+
+
+ org.apache.hbase
+ hbase-annotations
+
+ io.nettynetty
-
-
- commons-logging
- commons-logging
-
-
- org.jruby
- jruby-complete
-
-
-
+
+
+
+
+ org.apache.hbase
+ hbase-server
+ ${hbase.version}
+
+
+ org.apache.hadoop
+ hadoop-core
+
+
+ org.apache.hadoop
+ hadoop-client
+
+
+ org.apache.hadoop
+ hadoop-mapreduce-client-jobclient
+
+
+ org.apache.hadoop
+ hadoop-mapreduce-client-core
+
+
+ org.apache.hadoop
+ hadoop-auth
+
+
+
+ org.apache.hbase
+ hbase-annotations
+
+
+ org.apache.hadoop
+ hadoop-annotations
+
+
+ org.apache.hadoop
+ hadoop-hdfs
+
+
+ org.apache.hbase
+ hbase-hadoop1-compat
+
+
+ org.apache.commons
+ commons-math
+
+
+ com.sun.jersey
+ jersey-core
+
+
+ org.slf4j
+ slf4j-api
+
+
+ com.sun.jersey
+ jersey-server
+
+
+ com.sun.jersey
+ jersey-core
+
+
+ com.sun.jersey
+ jersey-json
+
+
+
+ commons-io
+ commons-io
+
+
+
+
+ org.apache.hbase
+ hbase-hadoop-compat
+ ${hbase.version}
+
+
+ org.apache.hbase
+ hbase-hadoop-compat
+ ${hbase.version}
+ test-jar
+ test
+
- org.eclipse.jetty
- jetty-server
+ org.apache.commons
+ commons-math3com.twitteralgebird-core_${scala.binary.version}
- 0.1.11
+ 0.8.1org.scalatest
@@ -268,6 +356,10 @@
com.google.common.base.Optional**
+
+ org.apache.commons.math3
+ org.spark-project.commons.math3
+
@@ -284,4 +376,83 @@
+
+
+ kinesis-asl
+
+
+ org.apache.spark
+ spark-streaming-kinesis-asl_${scala.binary.version}
+ ${project.version}
+
+
+ org.apache.httpcomponents
+ httpclient
+ ${commons.httpclient.version}
+
+
+
+
+ hbase-hadoop2
+
+
+ hbase.profile
+ hadoop2
+
+
+
+ 0.98.7-hadoop2
+
+
+
+ hbase-hadoop1
+
+
+ !hbase.profile
+
+
+
+ 0.98.7-hadoop1
+
+
+
+
+ scala-2.10
+
+ !scala-2.11
+
+
+
+ org.apache.spark
+ spark-streaming-kafka_${scala.binary.version}
+ ${project.version}
+
+
+
+
+
+ org.codehaus.mojo
+ build-helper-maven-plugin
+
+
+ add-scala-sources
+ generate-sources
+
+ add-source
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java b/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java
similarity index 100%
rename from examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java
rename to examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala
similarity index 100%
rename from examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala
rename to examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java
index 6c177de359b60..31a79ddd3fff1 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java
@@ -30,12 +30,25 @@
/**
* Logistic regression based classification.
+ *
+ * This is an example implementation for learning how to use Spark. For more conventional use,
+ * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
+ * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs.
*/
public final class JavaHdfsLR {
private static final int D = 10; // Number of dimensions
private static final Random rand = new Random(42);
+ static void showWarning() {
+ String warning = "WARN: This is a naive implementation of Logistic Regression " +
+ "and is given as an example!\n" +
+ "Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD " +
+ "or org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS " +
+ "for more conventional use.";
+ System.err.println(warning);
+ }
+
static class DataPoint implements Serializable {
DataPoint(double[] x, double y) {
this.x = x;
@@ -109,6 +122,8 @@ public static void main(String[] args) {
System.exit(1);
}
+ showWarning();
+
SparkConf sparkConf = new SparkConf().setAppName("JavaHdfsLR");
JavaSparkContext sc = new JavaSparkContext(sparkConf);
JavaRDD lines = sc.textFile(args[0]);
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java
index c22506491fbff..a5db8accdf138 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java
@@ -45,10 +45,21 @@
* URL neighbor URL
* ...
* where URL and their neighbors are separated by space(s).
+ *
+ * This is an example implementation for learning how to use Spark. For more conventional use,
+ * please refer to org.apache.spark.graphx.lib.PageRank
*/
public final class JavaPageRank {
private static final Pattern SPACES = Pattern.compile("\\s+");
+ static void showWarning() {
+ String warning = "WARN: This is a naive implementation of PageRank " +
+ "and is given as an example! \n" +
+ "Please use the PageRank implementation found in " +
+ "org.apache.spark.graphx.lib.PageRank for more conventional use.";
+ System.err.println(warning);
+ }
+
private static class Sum implements Function2 {
@Override
public Double call(Double a, Double b) {
@@ -62,6 +73,8 @@ public static void main(String[] args) throws Exception {
System.exit(1);
}
+ showWarning();
+
SparkConf sparkConf = new SparkConf().setAppName("JavaPageRank");
JavaSparkContext ctx = new JavaSparkContext(sparkConf);
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java b/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java
new file mode 100644
index 0000000000000..e68ec74c3ed54
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkJobInfo;
+import org.apache.spark.SparkStageInfo;
+import org.apache.spark.api.java.JavaFutureAction;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Example of using Spark's status APIs from Java.
+ */
+public final class JavaStatusTrackerDemo {
+
+ public static final String APP_NAME = "JavaStatusAPIDemo";
+
+ public static final class IdentityWithDelay implements Function {
+ @Override
+ public T call(T x) throws Exception {
+ Thread.sleep(2 * 1000); // 2 seconds
+ return x;
+ }
+ }
+
+ public static void main(String[] args) throws Exception {
+ SparkConf sparkConf = new SparkConf().setAppName(APP_NAME);
+ final JavaSparkContext sc = new JavaSparkContext(sparkConf);
+
+ // Example of implementing a progress reporter for a simple job.
+ JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 5).map(
+ new IdentityWithDelay());
+ JavaFutureAction> jobFuture = rdd.collectAsync();
+ while (!jobFuture.isDone()) {
+ Thread.sleep(1000); // 1 second
+ List jobIds = jobFuture.jobIds();
+ if (jobIds.isEmpty()) {
+ continue;
+ }
+ int currentJobId = jobIds.get(jobIds.size() - 1);
+ SparkJobInfo jobInfo = sc.statusTracker().getJobInfo(currentJobId);
+ SparkStageInfo stageInfo = sc.statusTracker().getStageInfo(jobInfo.stageIds()[0]);
+ System.out.println(stageInfo.numTasks() + " tasks total: " + stageInfo.numActiveTasks() +
+ " active, " + stageInfo.numCompletedTasks() + " complete");
+ }
+
+ System.out.println("Job results are: " + jobFuture.get());
+ sc.stop();
+ }
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
new file mode 100644
index 0000000000000..22ba68d8c354c
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.Pipeline;
+import org.apache.spark.ml.PipelineModel;
+import org.apache.spark.ml.PipelineStage;
+import org.apache.spark.ml.classification.LogisticRegression;
+import org.apache.spark.ml.feature.HashingTF;
+import org.apache.spark.ml.feature.Tokenizer;
+import org.apache.spark.sql.api.java.JavaSQLContext;
+import org.apache.spark.sql.api.java.JavaSchemaRDD;
+import org.apache.spark.sql.api.java.Row;
+import org.apache.spark.SparkConf;
+
+/**
+ * A simple text classification pipeline that recognizes "spark" from input text. It uses the Java
+ * bean classes {@link LabeledDocument} and {@link Document} defined in the Scala counterpart of
+ * this example {@link SimpleTextClassificationPipeline}. Run with
+ *
+ */
+public class JavaSimpleTextClassificationPipeline {
+
+ public static void main(String[] args) {
+ SparkConf conf = new SparkConf().setAppName("JavaSimpleTextClassificationPipeline");
+ JavaSparkContext jsc = new JavaSparkContext(conf);
+ JavaSQLContext jsql = new JavaSQLContext(jsc);
+
+ // Prepare training documents, which are labeled.
+ List localTraining = Lists.newArrayList(
+ new LabeledDocument(0L, "a b c d e spark", 1.0),
+ new LabeledDocument(1L, "b d", 0.0),
+ new LabeledDocument(2L, "spark f g h", 1.0),
+ new LabeledDocument(3L, "hadoop mapreduce", 0.0));
+ JavaSchemaRDD training =
+ jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
+
+ // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
+ Tokenizer tokenizer = new Tokenizer()
+ .setInputCol("text")
+ .setOutputCol("words");
+ HashingTF hashingTF = new HashingTF()
+ .setNumFeatures(1000)
+ .setInputCol(tokenizer.getOutputCol())
+ .setOutputCol("features");
+ LogisticRegression lr = new LogisticRegression()
+ .setMaxIter(10)
+ .setRegParam(0.01);
+ Pipeline pipeline = new Pipeline()
+ .setStages(new PipelineStage[] {tokenizer, hashingTF, lr});
+
+ // Fit the pipeline to training documents.
+ PipelineModel model = pipeline.fit(training);
+
+ // Prepare test documents, which are unlabeled.
+ List localTest = Lists.newArrayList(
+ new Document(4L, "spark i j k"),
+ new Document(5L, "l m n"),
+ new Document(6L, "mapreduce spark"),
+ new Document(7L, "apache hadoop"));
+ JavaSchemaRDD test =
+ jsql.applySchema(jsc.parallelize(localTest), Document.class);
+
+ // Make predictions on test documents.
+ model.transform(test).registerAsTable("prediction");
+ JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
+ for (Row r: predictions.collect()) {
+ System.out.println(r);
+ }
+ }
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaALS.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaALS.java
index 8d381d4e0a943..95a430f1da234 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaALS.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaALS.java
@@ -32,7 +32,7 @@
import scala.Tuple2;
/**
- * Example using MLLib ALS from Java.
+ * Example using MLlib ALS from Java.
*/
public final class JavaALS {
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java
new file mode 100644
index 0000000000000..4a5ac404ea5ea
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib;
+
+import scala.Tuple2;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.tree.GradientBoostedTrees;
+import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
+import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;
+import org.apache.spark.mllib.util.MLUtils;
+
+/**
+ * Classification and regression using gradient-boosted decision trees.
+ */
+public final class JavaGradientBoostedTreesRunner {
+
+ private static void usage() {
+ System.err.println("Usage: JavaGradientBoostedTreesRunner " +
+ " ");
+ System.exit(-1);
+ }
+
+ public static void main(String[] args) {
+ String datapath = "data/mllib/sample_libsvm_data.txt";
+ String algo = "Classification";
+ if (args.length >= 1) {
+ datapath = args[0];
+ }
+ if (args.length >= 2) {
+ algo = args[1];
+ }
+ if (args.length > 2) {
+ usage();
+ }
+ SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTreesRunner");
+ JavaSparkContext sc = new JavaSparkContext(sparkConf);
+
+ JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
+
+ // Set parameters.
+ // Note: All features are treated as continuous.
+ BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo);
+ boostingStrategy.setNumIterations(10);
+ boostingStrategy.treeStrategy().setMaxDepth(5);
+
+ if (algo.equals("Classification")) {
+ // Compute the number of classes from the data.
+ Integer numClasses = data.map(new Function() {
+ @Override public Double call(LabeledPoint p) {
+ return p.label();
+ }
+ }).countByValue().size();
+ boostingStrategy.treeStrategy().setNumClassesForClassification(numClasses);
+
+ // Train a GradientBoosting model for classification.
+ final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);
+
+ // Evaluate model on training instances and compute training error
+ JavaPairRDD predictionAndLabel =
+ data.mapToPair(new PairFunction() {
+ @Override public Tuple2 call(LabeledPoint p) {
+ return new Tuple2(model.predict(p.features()), p.label());
+ }
+ });
+ Double trainErr =
+ 1.0 * predictionAndLabel.filter(new Function, Boolean>() {
+ @Override public Boolean call(Tuple2 pl) {
+ return !pl._1().equals(pl._2());
+ }
+ }).count() / data.count();
+ System.out.println("Training error: " + trainErr);
+ System.out.println("Learned classification tree model:\n" + model);
+ } else if (algo.equals("Regression")) {
+ // Train a GradientBoosting model for classification.
+ final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy);
+
+ // Evaluate model on training instances and compute training error
+ JavaPairRDD predictionAndLabel =
+ data.mapToPair(new PairFunction() {
+ @Override public Tuple2 call(LabeledPoint p) {
+ return new Tuple2(model.predict(p.features()), p.label());
+ }
+ });
+ Double trainMSE =
+ predictionAndLabel.map(new Function, Double>() {
+ @Override public Double call(Tuple2 pl) {
+ Double diff = pl._1() - pl._2();
+ return diff * diff;
+ }
+ }).reduce(new Function2() {
+ @Override public Double call(Double a, Double b) {
+ return a + b;
+ }
+ }) / data.count();
+ System.out.println("Training Mean Squared Error: " + trainMSE);
+ System.out.println("Learned regression tree model:\n" + model);
+ } else {
+ usage();
+ }
+
+ sc.stop();
+ }
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java
index f796123a25727..e575eedeb465c 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java
@@ -30,7 +30,7 @@
import org.apache.spark.mllib.linalg.Vectors;
/**
- * Example using MLLib KMeans from Java.
+ * Example using MLlib KMeans from Java.
*/
public final class JavaKMeans {
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java
index 5622df5ce03ff..99df259b4e8e6 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java
@@ -57,7 +57,7 @@ public class JavaCustomReceiver extends Receiver {
public static void main(String[] args) {
if (args.length < 2) {
- System.err.println("Usage: JavaNetworkWordCount ");
+ System.err.println("Usage: JavaCustomReceiver ");
System.exit(1);
}
@@ -70,7 +70,7 @@ public static void main(String[] args) {
// Create a input stream with the custom receiver on target ip:port and count the
// words in input stream of \n delimited text (eg. generated by 'nc')
JavaReceiverInputDStream lines = ssc.receiverStream(
- new JavaCustomReceiver(args[1], Integer.parseInt(args[2])));
+ new JavaCustomReceiver(args[0], Integer.parseInt(args[1])));
JavaDStream words = lines.flatMap(new FlatMapFunction() {
@Override
public Iterable call(String x) {
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java
index 45bcedebb4117..3e9f0f4b8f127 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java
@@ -25,7 +25,7 @@
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.StorageLevels;
-import org.apache.spark.streaming.Duration;
+import org.apache.spark.streaming.Durations;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaReceiverInputDStream;
@@ -35,8 +35,9 @@
/**
* Counts words in UTF8 encoded, '\n' delimited text received from the network every second.
+ *
* Usage: JavaNetworkWordCount
- * and describe the TCP server that Spark Streaming would connect to receive data.
+ * and describe the TCP server that Spark Streaming would connect to receive data.
*
* To run this on your local machine, you need to first run a Netcat server
* `$ nc -lk 9999`
@@ -56,7 +57,7 @@ public static void main(String[] args) {
// Create the context with a 1 second batch size
SparkConf sparkConf = new SparkConf().setAppName("JavaNetworkWordCount");
- JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, new Duration(1000));
+ JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1));
// Create a JavaReceiverInputDStream on target ip:port and count the
// words in input stream of \n delimited text (eg. generated by 'nc')
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java
new file mode 100644
index 0000000000000..bceda97f058ea
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.streaming;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.charset.Charset;
+import java.util.Arrays;
+import java.util.regex.Pattern;
+
+import scala.Tuple2;
+import com.google.common.collect.Lists;
+import com.google.common.io.Files;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.function.FlatMapFunction;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.streaming.Durations;
+import org.apache.spark.streaming.Time;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.streaming.api.java.JavaPairDStream;
+import org.apache.spark.streaming.api.java.JavaReceiverInputDStream;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import org.apache.spark.streaming.api.java.JavaStreamingContextFactory;
+
+/**
+ * Counts words in text encoded with UTF8 received from the network every second.
+ *
+ * Usage: JavaRecoverableNetworkWordCount
+ * and describe the TCP server that Spark Streaming would connect to receive
+ * data. directory to HDFS-compatible file system which checkpoint data
+ * file to which the word counts will be appended
+ *
+ * and must be absolute paths
+ *
+ * To run this on your local machine, you need to first run a Netcat server
+ *
+ * `$ nc -lk 9999`
+ *
+ * and run the example as
+ *
+ * `$ ./bin/run-example org.apache.spark.examples.streaming.JavaRecoverableNetworkWordCount \
+ * localhost 9999 ~/checkpoint/ ~/out`
+ *
+ * If the directory ~/checkpoint/ does not exist (e.g. running for the first time), it will create
+ * a new StreamingContext (will print "Creating new context" to the console). Otherwise, if
+ * checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from
+ * the checkpoint data.
+ *
+ * Refer to the online documentation for more details.
+ */
+public final class JavaRecoverableNetworkWordCount {
+ private static final Pattern SPACE = Pattern.compile(" ");
+
+ private static JavaStreamingContext createContext(String ip,
+ int port,
+ String checkpointDirectory,
+ String outputPath) {
+
+ // If you do not see this printed, that means the StreamingContext has been loaded
+ // from the new checkpoint
+ System.out.println("Creating new context");
+ final File outputFile = new File(outputPath);
+ if (outputFile.exists()) {
+ outputFile.delete();
+ }
+ SparkConf sparkConf = new SparkConf().setAppName("JavaRecoverableNetworkWordCount");
+ // Create the context with a 1 second batch size
+ JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1));
+ ssc.checkpoint(checkpointDirectory);
+
+ // Create a socket stream on target ip:port and count the
+ // words in input stream of \n delimited text (eg. generated by 'nc')
+ JavaReceiverInputDStream lines = ssc.socketTextStream(ip, port);
+ JavaDStream words = lines.flatMap(new FlatMapFunction() {
+ @Override
+ public Iterable call(String x) {
+ return Lists.newArrayList(SPACE.split(x));
+ }
+ });
+ JavaPairDStream wordCounts = words.mapToPair(
+ new PairFunction() {
+ @Override
+ public Tuple2 call(String s) {
+ return new Tuple2(s, 1);
+ }
+ }).reduceByKey(new Function2() {
+ @Override
+ public Integer call(Integer i1, Integer i2) {
+ return i1 + i2;
+ }
+ });
+
+ wordCounts.foreachRDD(new Function2, Time, Void>() {
+ @Override
+ public Void call(JavaPairRDD rdd, Time time) throws IOException {
+ String counts = "Counts at time " + time + " " + rdd.collect();
+ System.out.println(counts);
+ System.out.println("Appending to " + outputFile.getAbsolutePath());
+ Files.append(counts + "\n", outputFile, Charset.defaultCharset());
+ return null;
+ }
+ });
+
+ return ssc;
+ }
+
+ public static void main(String[] args) {
+ if (args.length != 4) {
+ System.err.println("You arguments were " + Arrays.asList(args));
+ System.err.println(
+ "Usage: JavaRecoverableNetworkWordCount \n" +
+ " . and describe the TCP server that Spark\n" +
+ " Streaming would connect to receive data. directory to\n" +
+ " HDFS-compatible file system which checkpoint data file to which\n" +
+ " the word counts will be appended\n" +
+ "\n" +
+ "In local mode, should be 'local[n]' with n > 1\n" +
+ "Both and must be absolute paths");
+ System.exit(1);
+ }
+
+ final String ip = args[0];
+ final int port = Integer.parseInt(args[1]);
+ final String checkpointDirectory = args[2];
+ final String outputPath = args[3];
+ JavaStreamingContextFactory factory = new JavaStreamingContextFactory() {
+ @Override
+ public JavaStreamingContext create() {
+ return createContext(ip, port, checkpointDirectory, outputPath);
+ }
+ };
+ JavaStreamingContext ssc = JavaStreamingContext.getOrCreate(checkpointDirectory, factory);
+ ssc.start();
+ ssc.awaitTermination();
+ }
+}
diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py
new file mode 100644
index 0000000000000..540dae785f6ea
--- /dev/null
+++ b/examples/src/main/python/mllib/dataset_example.py
@@ -0,0 +1,62 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+An example of how to use SchemaRDD as a dataset for ML. Run with::
+ bin/spark-submit examples/src/main/python/mllib/dataset_example.py
+"""
+
+import os
+import sys
+import tempfile
+import shutil
+
+from pyspark import SparkContext
+from pyspark.sql import SQLContext
+from pyspark.mllib.util import MLUtils
+from pyspark.mllib.stat import Statistics
+
+
+def summarize(dataset):
+ print "schema: %s" % dataset.schema().json()
+ labels = dataset.map(lambda r: r.label)
+ print "label average: %f" % labels.mean()
+ features = dataset.map(lambda r: r.features)
+ summary = Statistics.colStats(features)
+ print "features average: %r" % summary.mean()
+
+if __name__ == "__main__":
+ if len(sys.argv) > 2:
+ print >> sys.stderr, "Usage: dataset_example.py "
+ exit(-1)
+ sc = SparkContext(appName="DatasetExample")
+ sqlCtx = SQLContext(sc)
+ if len(sys.argv) == 2:
+ input = sys.argv[1]
+ else:
+ input = "data/mllib/sample_libsvm_data.txt"
+ points = MLUtils.loadLibSVMFile(sc, input)
+ dataset0 = sqlCtx.inferSchema(points).setName("dataset0").cache()
+ summarize(dataset0)
+ tempdir = tempfile.NamedTemporaryFile(delete=False).name
+ os.unlink(tempdir)
+ print "Save dataset as a Parquet file to %s." % tempdir
+ dataset0.saveAsParquetFile(tempdir)
+ print "Load it back and summarize it again."
+ dataset1 = sqlCtx.parquetFile(tempdir).setName("dataset1").cache()
+ summarize(dataset1)
+ shutil.rmtree(tempdir)
diff --git a/examples/src/main/python/mllib/word2vec.py b/examples/src/main/python/mllib/word2vec.py
new file mode 100644
index 0000000000000..99fef4276a369
--- /dev/null
+++ b/examples/src/main/python/mllib/word2vec.py
@@ -0,0 +1,50 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This example uses text8 file from http://mattmahoney.net/dc/text8.zip
+# The file was downloadded, unziped and split into multiple lines using
+#
+# wget http://mattmahoney.net/dc/text8.zip
+# unzip text8.zip
+# grep -o -E '\w+(\W+\w+){0,15}' text8 > text8_lines
+# This was done so that the example can be run in local mode
+
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.mllib.feature import Word2Vec
+
+USAGE = ("bin/spark-submit --driver-memory 4g "
+ "examples/src/main/python/mllib/word2vec.py text8_lines")
+
+if __name__ == "__main__":
+ if len(sys.argv) < 2:
+ print USAGE
+ sys.exit("Argument for file not provided")
+ file_path = sys.argv[1]
+ sc = SparkContext(appName='Word2Vec')
+ inp = sc.textFile(file_path).map(lambda row: row.split(" "))
+
+ word2vec = Word2Vec()
+ model = word2vec.fit(inp)
+
+ synonyms = model.findSynonyms('china', 40)
+
+ for word, cosine_distance in synonyms:
+ print "{}: {}".format(word, cosine_distance)
+ sc.stop()
diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py
index b539c4128cdcc..a5f25d78c1146 100755
--- a/examples/src/main/python/pagerank.py
+++ b/examples/src/main/python/pagerank.py
@@ -15,6 +15,11 @@
# limitations under the License.
#
+"""
+This is an example implementation of PageRank. For more conventional use,
+Please refer to PageRank implementation provided by graphx
+"""
+
import re
import sys
from operator import add
@@ -40,6 +45,9 @@ def parseNeighbors(urls):
print >> sys.stderr, "Usage: pagerank "
exit(-1)
+ print >> sys.stderr, """WARN: This is a naive implementation of PageRank and is
+ given as an example! Please refer to PageRank implementation provided by graphx"""
+
# Initialize the spark context.
sc = SparkContext(appName="PythonPageRank")
diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py
index eefa022f1927c..d2c5ca48c6cb8 100644
--- a/examples/src/main/python/sql.py
+++ b/examples/src/main/python/sql.py
@@ -48,7 +48,7 @@
# A JSON dataset is pointed to by path.
# The path can be either a single text file or a directory storing text files.
- path = os.environ['SPARK_HOME'] + "examples/src/main/resources/people.json"
+ path = os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json")
# Create a SchemaRDD from the file(s) pointed to by path
people = sqlContext.jsonFile(path)
# root
diff --git a/examples/src/main/python/streaming/hdfs_wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py
new file mode 100644
index 0000000000000..f7ffb5379681e
--- /dev/null
+++ b/examples/src/main/python/streaming/hdfs_wordcount.py
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+ Counts words in new text files created in the given directory
+ Usage: hdfs_wordcount.py
+ is the directory that Spark Streaming will use to find and read new text files.
+
+ To run this on your local machine on directory `localdir`, run this example
+ $ bin/spark-submit examples/src/main/python/streaming/hdfs_wordcount.py localdir
+
+ Then create a text file in `localdir` and the words in the file will get counted.
+"""
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+
+if __name__ == "__main__":
+ if len(sys.argv) != 2:
+ print >> sys.stderr, "Usage: hdfs_wordcount.py "
+ exit(-1)
+
+ sc = SparkContext(appName="PythonStreamingHDFSWordCount")
+ ssc = StreamingContext(sc, 1)
+
+ lines = ssc.textFileStream(sys.argv[1])
+ counts = lines.flatMap(lambda line: line.split(" "))\
+ .map(lambda x: (x, 1))\
+ .reduceByKey(lambda a, b: a+b)
+ counts.pprint()
+
+ ssc.start()
+ ssc.awaitTermination()
diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py
new file mode 100644
index 0000000000000..cfa9c1ff5bfbc
--- /dev/null
+++ b/examples/src/main/python/streaming/network_wordcount.py
@@ -0,0 +1,48 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+ Counts words in UTF8 encoded, '\n' delimited text received from the network every second.
+ Usage: network_wordcount.py
+ and describe the TCP server that Spark Streaming would connect to receive data.
+
+ To run this on your local machine, you need to first run a Netcat server
+ `$ nc -lk 9999`
+ and then run the example
+ `$ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999`
+"""
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+
+if __name__ == "__main__":
+ if len(sys.argv) != 3:
+ print >> sys.stderr, "Usage: network_wordcount.py "
+ exit(-1)
+ sc = SparkContext(appName="PythonStreamingNetworkWordCount")
+ ssc = StreamingContext(sc, 1)
+
+ lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2]))
+ counts = lines.flatMap(lambda line: line.split(" "))\
+ .map(lambda word: (word, 1))\
+ .reduceByKey(lambda a, b: a+b)
+ counts.pprint()
+
+ ssc.start()
+ ssc.awaitTermination()
diff --git a/examples/src/main/python/streaming/recoverable_network_wordcount.py b/examples/src/main/python/streaming/recoverable_network_wordcount.py
new file mode 100644
index 0000000000000..fc6827c82bf9b
--- /dev/null
+++ b/examples/src/main/python/streaming/recoverable_network_wordcount.py
@@ -0,0 +1,80 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+ Counts words in text encoded with UTF8 received from the network every second.
+
+ Usage: recoverable_network_wordcount.py
+ and describe the TCP server that Spark Streaming would connect to receive
+ data. directory to HDFS-compatible file system which checkpoint data
+ file to which the word counts will be appended
+
+ To run this on your local machine, you need to first run a Netcat server
+ `$ nc -lk 9999`
+
+ and then run the example
+ `$ bin/spark-submit examples/src/main/python/streaming/recoverable_network_wordcount.py \
+ localhost 9999 ~/checkpoint/ ~/out`
+
+ If the directory ~/checkpoint/ does not exist (e.g. running for the first time), it will create
+ a new StreamingContext (will print "Creating new context" to the console). Otherwise, if
+ checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from
+ the checkpoint data.
+"""
+
+import os
+import sys
+
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+
+
+def createContext(host, port, outputPath):
+ # If you do not see this printed, that means the StreamingContext has been loaded
+ # from the new checkpoint
+ print "Creating new context"
+ if os.path.exists(outputPath):
+ os.remove(outputPath)
+ sc = SparkContext(appName="PythonStreamingRecoverableNetworkWordCount")
+ ssc = StreamingContext(sc, 1)
+
+ # Create a socket stream on target ip:port and count the
+ # words in input stream of \n delimited text (eg. generated by 'nc')
+ lines = ssc.socketTextStream(host, port)
+ words = lines.flatMap(lambda line: line.split(" "))
+ wordCounts = words.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y)
+
+ def echo(time, rdd):
+ counts = "Counts at time %s %s" % (time, rdd.collect())
+ print counts
+ print "Appending to " + os.path.abspath(outputPath)
+ with open(outputPath, 'a') as f:
+ f.write(counts + "\n")
+
+ wordCounts.foreachRDD(echo)
+ return ssc
+
+if __name__ == "__main__":
+ if len(sys.argv) != 5:
+ print >> sys.stderr, "Usage: recoverable_network_wordcount.py "\
+ ""
+ exit(-1)
+ host, port, checkpoint, output = sys.argv[1:]
+ ssc = StreamingContext.getOrCreate(checkpoint,
+ lambda: createContext(host, int(port), output))
+ ssc.start()
+ ssc.awaitTermination()
diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py
new file mode 100644
index 0000000000000..18a9a5a452ffb
--- /dev/null
+++ b/examples/src/main/python/streaming/stateful_network_wordcount.py
@@ -0,0 +1,57 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+ Counts words in UTF8 encoded, '\n' delimited text received from the
+ network every second.
+
+ Usage: stateful_network_wordcount.py
+ and describe the TCP server that Spark Streaming
+ would connect to receive data.
+
+ To run this on your local machine, you need to first run a Netcat server
+ `$ nc -lk 9999`
+ and then run the example
+ `$ bin/spark-submit examples/src/main/python/streaming/stateful_network_wordcount.py \
+ localhost 9999`
+"""
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+
+if __name__ == "__main__":
+ if len(sys.argv) != 3:
+ print >> sys.stderr, "Usage: stateful_network_wordcount.py "
+ exit(-1)
+ sc = SparkContext(appName="PythonStreamingStatefulNetworkWordCount")
+ ssc = StreamingContext(sc, 1)
+ ssc.checkpoint("checkpoint")
+
+ def updateFunc(new_values, last_sum):
+ return sum(new_values) + (last_sum or 0)
+
+ lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2]))
+ running_counts = lines.flatMap(lambda line: line.split(" "))\
+ .map(lambda word: (word, 1))\
+ .updateStateByKey(updateFunc)
+
+ running_counts.pprint()
+
+ ssc.start()
+ ssc.awaitTermination()
diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala
index 1f576319b3ca8..3d5259463003d 100644
--- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala
@@ -17,11 +17,7 @@
package org.apache.spark.examples
-import scala.math.sqrt
-
-import cern.colt.matrix._
-import cern.colt.matrix.linalg._
-import cern.jet.math._
+import org.apache.commons.math3.linear._
/**
* Alternating least squares matrix factorization.
@@ -30,84 +26,70 @@ import cern.jet.math._
* please refer to org.apache.spark.mllib.recommendation.ALS
*/
object LocalALS {
+
// Parameters set through command line arguments
var M = 0 // Number of movies
var U = 0 // Number of users
var F = 0 // Number of features
var ITERATIONS = 0
-
val LAMBDA = 0.01 // Regularization coefficient
- // Some COLT objects
- val factory2D = DoubleFactory2D.dense
- val factory1D = DoubleFactory1D.dense
- val algebra = Algebra.DEFAULT
- val blas = SeqBlas.seqBlas
-
- def generateR(): DoubleMatrix2D = {
- val mh = factory2D.random(M, F)
- val uh = factory2D.random(U, F)
- algebra.mult(mh, algebra.transpose(uh))
+ def generateR(): RealMatrix = {
+ val mh = randomMatrix(M, F)
+ val uh = randomMatrix(U, F)
+ mh.multiply(uh.transpose())
}
- def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D],
- us: Array[DoubleMatrix1D]): Double =
- {
- val r = factory2D.make(M, U)
+ def rmse(targetR: RealMatrix, ms: Array[RealVector], us: Array[RealVector]): Double = {
+ val r = new Array2DRowRealMatrix(M, U)
for (i <- 0 until M; j <- 0 until U) {
- r.set(i, j, blas.ddot(ms(i), us(j)))
+ r.setEntry(i, j, ms(i).dotProduct(us(j)))
}
- blas.daxpy(-1, targetR, r)
- val sumSqs = r.aggregate(Functions.plus, Functions.square)
- sqrt(sumSqs / (M * U))
+ val diffs = r.subtract(targetR)
+ var sumSqs = 0.0
+ for (i <- 0 until M; j <- 0 until U) {
+ val diff = diffs.getEntry(i, j)
+ sumSqs += diff * diff
+ }
+ math.sqrt(sumSqs / (M.toDouble * U.toDouble))
}
- def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D],
- R: DoubleMatrix2D) : DoubleMatrix1D =
- {
- val XtX = factory2D.make(F, F)
- val Xty = factory1D.make(F)
+ def updateMovie(i: Int, m: RealVector, us: Array[RealVector], R: RealMatrix) : RealVector = {
+ var XtX: RealMatrix = new Array2DRowRealMatrix(F, F)
+ var Xty: RealVector = new ArrayRealVector(F)
// For each user that rated the movie
for (j <- 0 until U) {
val u = us(j)
// Add u * u^t to XtX
- blas.dger(1, u, u, XtX)
+ XtX = XtX.add(u.outerProduct(u))
// Add u * rating to Xty
- blas.daxpy(R.get(i, j), u, Xty)
+ Xty = Xty.add(u.mapMultiply(R.getEntry(i, j)))
}
- // Add regularization coefs to diagonal terms
+ // Add regularization coefficients to diagonal terms
for (d <- 0 until F) {
- XtX.set(d, d, XtX.get(d, d) + LAMBDA * U)
+ XtX.addToEntry(d, d, LAMBDA * U)
}
// Solve it with Cholesky
- val ch = new CholeskyDecomposition(XtX)
- val Xty2D = factory2D.make(Xty.toArray, F)
- val solved2D = ch.solve(Xty2D)
- solved2D.viewColumn(0)
+ new CholeskyDecomposition(XtX).getSolver.solve(Xty)
}
- def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D],
- R: DoubleMatrix2D) : DoubleMatrix1D =
- {
- val XtX = factory2D.make(F, F)
- val Xty = factory1D.make(F)
+ def updateUser(j: Int, u: RealVector, ms: Array[RealVector], R: RealMatrix) : RealVector = {
+ var XtX: RealMatrix = new Array2DRowRealMatrix(F, F)
+ var Xty: RealVector = new ArrayRealVector(F)
// For each movie that the user rated
for (i <- 0 until M) {
val m = ms(i)
// Add m * m^t to XtX
- blas.dger(1, m, m, XtX)
+ XtX = XtX.add(m.outerProduct(m))
// Add m * rating to Xty
- blas.daxpy(R.get(i, j), m, Xty)
+ Xty = Xty.add(m.mapMultiply(R.getEntry(i, j)))
}
- // Add regularization coefs to diagonal terms
+ // Add regularization coefficients to diagonal terms
for (d <- 0 until F) {
- XtX.set(d, d, XtX.get(d, d) + LAMBDA * M)
+ XtX.addToEntry(d, d, LAMBDA * M)
}
// Solve it with Cholesky
- val ch = new CholeskyDecomposition(XtX)
- val Xty2D = factory2D.make(Xty.toArray, F)
- val solved2D = ch.solve(Xty2D)
- solved2D.viewColumn(0)
+ new CholeskyDecomposition(XtX).getSolver.solve(Xty)
}
def showWarning() {
@@ -135,21 +117,28 @@ object LocalALS {
showWarning()
- printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS)
+ println(s"Running with M=$M, U=$U, F=$F, iters=$ITERATIONS")
val R = generateR()
// Initialize m and u randomly
- var ms = Array.fill(M)(factory1D.random(F))
- var us = Array.fill(U)(factory1D.random(F))
+ var ms = Array.fill(M)(randomVector(F))
+ var us = Array.fill(U)(randomVector(F))
// Iteratively update movies then users
for (iter <- 1 to ITERATIONS) {
- println("Iteration " + iter + ":")
+ println(s"Iteration $iter:")
ms = (0 until M).map(i => updateMovie(i, ms(i), us, R)).toArray
us = (0 until U).map(j => updateUser(j, us(j), ms, R)).toArray
println("RMSE = " + rmse(R, ms, us))
println()
}
}
+
+ private def randomVector(n: Int): RealVector =
+ new ArrayRealVector(Array.fill(n)(math.random))
+
+ private def randomMatrix(rows: Int, cols: Int): RealMatrix =
+ new Array2DRowRealMatrix(Array.fill(rows, cols)(math.random))
+
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala
index 931faac5463c4..ac2ea35bbd0e0 100644
--- a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala
@@ -25,7 +25,8 @@ import breeze.linalg.{Vector, DenseVector}
* Logistic regression based classification.
*
* This is an example implementation for learning how to use Spark. For more conventional use,
- * please refer to org.apache.spark.mllib.classification.LogisticRegression
+ * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
+ * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs.
*/
object LocalFileLR {
val D = 10 // Numer of dimensions
@@ -41,7 +42,8 @@ object LocalFileLR {
def showWarning() {
System.err.println(
"""WARN: This is a naive implementation of Logistic Regression and is given as an example!
- |Please use the LogisticRegression method found in org.apache.spark.mllib.classification
+ |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
+ |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
|for more conventional use.
""".stripMargin)
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala
index 2d75b9d2590f8..92a683ad57ea1 100644
--- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala
@@ -25,7 +25,8 @@ import breeze.linalg.{Vector, DenseVector}
* Logistic regression based classification.
*
* This is an example implementation for learning how to use Spark. For more conventional use,
- * please refer to org.apache.spark.mllib.classification.LogisticRegression
+ * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
+ * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs.
*/
object LocalLR {
val N = 10000 // Number of data points
@@ -48,7 +49,8 @@ object LocalLR {
def showWarning() {
System.err.println(
"""WARN: This is a naive implementation of Logistic Regression and is given as an example!
- |Please use the LogisticRegression method found in org.apache.spark.mllib.classification
+ |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
+ |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
|for more conventional use.
""".stripMargin)
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala
index fde8ffeedf8b4..6c0ac8013ce34 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala
@@ -17,11 +17,7 @@
package org.apache.spark.examples
-import scala.math.sqrt
-
-import cern.colt.matrix._
-import cern.colt.matrix.linalg._
-import cern.jet.math._
+import org.apache.commons.math3.linear._
import org.apache.spark._
@@ -32,62 +28,53 @@ import org.apache.spark._
* please refer to org.apache.spark.mllib.recommendation.ALS
*/
object SparkALS {
+
// Parameters set through command line arguments
var M = 0 // Number of movies
var U = 0 // Number of users
var F = 0 // Number of features
var ITERATIONS = 0
-
val LAMBDA = 0.01 // Regularization coefficient
- // Some COLT objects
- val factory2D = DoubleFactory2D.dense
- val factory1D = DoubleFactory1D.dense
- val algebra = Algebra.DEFAULT
- val blas = SeqBlas.seqBlas
-
- def generateR(): DoubleMatrix2D = {
- val mh = factory2D.random(M, F)
- val uh = factory2D.random(U, F)
- algebra.mult(mh, algebra.transpose(uh))
+ def generateR(): RealMatrix = {
+ val mh = randomMatrix(M, F)
+ val uh = randomMatrix(U, F)
+ mh.multiply(uh.transpose())
}
- def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D],
- us: Array[DoubleMatrix1D]): Double =
- {
- val r = factory2D.make(M, U)
+ def rmse(targetR: RealMatrix, ms: Array[RealVector], us: Array[RealVector]): Double = {
+ val r = new Array2DRowRealMatrix(M, U)
for (i <- 0 until M; j <- 0 until U) {
- r.set(i, j, blas.ddot(ms(i), us(j)))
+ r.setEntry(i, j, ms(i).dotProduct(us(j)))
}
- blas.daxpy(-1, targetR, r)
- val sumSqs = r.aggregate(Functions.plus, Functions.square)
- sqrt(sumSqs / (M * U))
+ val diffs = r.subtract(targetR)
+ var sumSqs = 0.0
+ for (i <- 0 until M; j <- 0 until U) {
+ val diff = diffs.getEntry(i, j)
+ sumSqs += diff * diff
+ }
+ math.sqrt(sumSqs / (M.toDouble * U.toDouble))
}
- def update(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D],
- R: DoubleMatrix2D) : DoubleMatrix1D =
- {
+ def update(i: Int, m: RealVector, us: Array[RealVector], R: RealMatrix) : RealVector = {
val U = us.size
- val F = us(0).size
- val XtX = factory2D.make(F, F)
- val Xty = factory1D.make(F)
+ val F = us(0).getDimension
+ var XtX: RealMatrix = new Array2DRowRealMatrix(F, F)
+ var Xty: RealVector = new ArrayRealVector(F)
// For each user that rated the movie
for (j <- 0 until U) {
val u = us(j)
// Add u * u^t to XtX
- blas.dger(1, u, u, XtX)
+ XtX = XtX.add(u.outerProduct(u))
// Add u * rating to Xty
- blas.daxpy(R.get(i, j), u, Xty)
+ Xty = Xty.add(u.mapMultiply(R.getEntry(i, j)))
}
// Add regularization coefs to diagonal terms
for (d <- 0 until F) {
- XtX.set(d, d, XtX.get(d, d) + LAMBDA * U)
+ XtX.addToEntry(d, d, LAMBDA * U)
}
// Solve it with Cholesky
- val ch = new CholeskyDecomposition(XtX)
- val Xty2D = factory2D.make(Xty.toArray, F)
- val solved2D = ch.solve(Xty2D)
- solved2D.viewColumn(0)
+ new CholeskyDecomposition(XtX).getSolver.solve(Xty)
}
def showWarning() {
@@ -118,7 +105,7 @@ object SparkALS {
showWarning()
- printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS)
+ println(s"Running with M=$M, U=$U, F=$F, iters=$ITERATIONS")
val sparkConf = new SparkConf().setAppName("SparkALS")
val sc = new SparkContext(sparkConf)
@@ -126,21 +113,21 @@ object SparkALS {
val R = generateR()
// Initialize m and u randomly
- var ms = Array.fill(M)(factory1D.random(F))
- var us = Array.fill(U)(factory1D.random(F))
+ var ms = Array.fill(M)(randomVector(F))
+ var us = Array.fill(U)(randomVector(F))
// Iteratively update movies then users
val Rc = sc.broadcast(R)
var msb = sc.broadcast(ms)
var usb = sc.broadcast(us)
for (iter <- 1 to ITERATIONS) {
- println("Iteration " + iter + ":")
+ println(s"Iteration $iter:")
ms = sc.parallelize(0 until M, slices)
.map(i => update(i, msb.value(i), usb.value, Rc.value))
.collect()
msb = sc.broadcast(ms) // Re-broadcast ms because it was updated
us = sc.parallelize(0 until U, slices)
- .map(i => update(i, usb.value(i), msb.value, algebra.transpose(Rc.value)))
+ .map(i => update(i, usb.value(i), msb.value, Rc.value.transpose()))
.collect()
usb = sc.broadcast(us) // Re-broadcast us because it was updated
println("RMSE = " + rmse(R, ms, us))
@@ -149,4 +136,11 @@ object SparkALS {
sc.stop()
}
+
+ private def randomVector(n: Int): RealVector =
+ new ArrayRealVector(Array.fill(n)(math.random))
+
+ private def randomMatrix(rows: Int, cols: Int): RealMatrix =
+ new Array2DRowRealMatrix(Array.fill(rows, cols)(math.random))
+
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
index 3258510894372..9099c2fcc90b3 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
@@ -32,7 +32,8 @@ import org.apache.spark.scheduler.InputFormatInfo
* Logistic regression based classification.
*
* This is an example implementation for learning how to use Spark. For more conventional use,
- * please refer to org.apache.spark.mllib.classification.LogisticRegression
+ * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
+ * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs.
*/
object SparkHdfsLR {
val D = 10 // Numer of dimensions
@@ -54,7 +55,8 @@ object SparkHdfsLR {
def showWarning() {
System.err.println(
"""WARN: This is a naive implementation of Logistic Regression and is given as an example!
- |Please use the LogisticRegression method found in org.apache.spark.mllib.classification
+ |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
+ |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
|for more conventional use.
""".stripMargin)
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala
index fc23308fc4adf..257a7d29f922a 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala
@@ -30,7 +30,8 @@ import org.apache.spark._
* Usage: SparkLR [slices]
*
* This is an example implementation for learning how to use Spark. For more conventional use,
- * please refer to org.apache.spark.mllib.classification.LogisticRegression
+ * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
+ * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs.
*/
object SparkLR {
val N = 10000 // Number of data points
@@ -53,7 +54,8 @@ object SparkLR {
def showWarning() {
System.err.println(
"""WARN: This is a naive implementation of Logistic Regression and is given as an example!
- |Please use the LogisticRegression method found in org.apache.spark.mllib.classification
+ |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
+ |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
|for more conventional use.
""".stripMargin)
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala
index 4c7e006da0618..8d092b6506d33 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala
@@ -28,13 +28,28 @@ import org.apache.spark.{SparkConf, SparkContext}
* URL neighbor URL
* ...
* where URL and their neighbors are separated by space(s).
+ *
+ * This is an example implementation for learning how to use Spark. For more conventional use,
+ * please refer to org.apache.spark.graphx.lib.PageRank
*/
object SparkPageRank {
+
+ def showWarning() {
+ System.err.println(
+ """WARN: This is a naive implementation of PageRank and is given as an example!
+ |Please use the PageRank implementation found in org.apache.spark.graphx.lib.PageRank
+ |for more conventional use.
+ """.stripMargin)
+ }
+
def main(args: Array[String]) {
if (args.length < 1) {
System.err.println("Usage: SparkPageRank ")
System.exit(1)
}
+
+ showWarning()
+
val sparkConf = new SparkConf().setAppName("PageRank")
val iters = if (args.length > 0) args(1).toInt else 10
val ctx = new SparkContext(sparkConf)
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala
index 96d13612e46dd..4393b99e636b6 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala
@@ -32,11 +32,24 @@ import org.apache.spark.storage.StorageLevel
/**
* Logistic regression based classification.
* This example uses Tachyon to persist rdds during computation.
+ *
+ * This is an example implementation for learning how to use Spark. For more conventional use,
+ * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
+ * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs.
*/
object SparkTachyonHdfsLR {
val D = 10 // Numer of dimensions
val rand = new Random(42)
+ def showWarning() {
+ System.err.println(
+ """WARN: This is a naive implementation of Logistic Regression and is given as an example!
+ |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or
+ |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
+ |for more conventional use.
+ """.stripMargin)
+ }
+
case class DataPoint(x: Vector[Double], y: Double)
def parsePoint(line: String): DataPoint = {
@@ -51,6 +64,9 @@ object SparkTachyonHdfsLR {
}
def main(args: Array[String]) {
+
+ showWarning()
+
val inputPath = args(0)
val sparkConf = new SparkConf().setAppName("SparkTachyonHdfsLR")
val conf = new Configuration()
diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala
index e06f4dcd54442..e322d4ce5a745 100644
--- a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala
@@ -18,17 +18,7 @@
package org.apache.spark.examples.bagel
import org.apache.spark._
-import org.apache.spark.SparkContext._
-import org.apache.spark.serializer.KryoRegistrator
-
import org.apache.spark.bagel._
-import org.apache.spark.bagel.Bagel._
-
-import scala.collection.mutable.ArrayBuffer
-
-import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream}
-
-import com.esotericsoftware.kryo._
class PageRankUtils extends Serializable {
def computeWithCombiner(numVertices: Long, epsilon: Double)(
@@ -99,13 +89,6 @@ class PRMessage() extends Message[String] with Serializable {
}
}
-class PRKryoRegistrator extends KryoRegistrator {
- def registerClasses(kryo: Kryo) {
- kryo.register(classOf[PRVertex])
- kryo.register(classOf[PRMessage])
- }
-}
-
class CustomPartitioner(partitions: Int) extends Partitioner {
def numPartitions = partitions
diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala
index e4db3ec51313d..859abedf2a55e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala
@@ -38,8 +38,7 @@ object WikipediaPageRank {
}
val sparkConf = new SparkConf()
sparkConf.setAppName("WikipediaPageRank")
- sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- sparkConf.set("spark.kryo.registrator", classOf[PRKryoRegistrator].getName)
+ sparkConf.registerKryoClasses(Array(classOf[PRVertex], classOf[PRMessage]))
val inputFile = args(0)
val threshold = args(1).toDouble
diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala
index c4317a6aec798..828cffb01ca1e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala
@@ -46,28 +46,15 @@ object Analytics extends Logging {
}
val options = mutable.Map(optionsList: _*)
- def pickPartitioner(v: String): PartitionStrategy = {
- // TODO: Use reflection rather than listing all the partitioning strategies here.
- v match {
- case "RandomVertexCut" => RandomVertexCut
- case "EdgePartition1D" => EdgePartition1D
- case "EdgePartition2D" => EdgePartition2D
- case "CanonicalRandomVertexCut" => CanonicalRandomVertexCut
- case _ => throw new IllegalArgumentException("Invalid PartitionStrategy: " + v)
- }
- }
-
- val conf = new SparkConf()
- .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
- .set("spark.locality.wait", "100000")
+ val conf = new SparkConf().set("spark.locality.wait", "100000")
+ GraphXUtils.registerKryoClasses(conf)
val numEPart = options.remove("numEPart").map(_.toInt).getOrElse {
println("Set the number of edge partitions using --numEPart.")
sys.exit(1)
}
val partitionStrategy: Option[PartitionStrategy] = options.remove("partStrategy")
- .map(pickPartitioner(_))
+ .map(PartitionStrategy.fromString(_))
val edgeStorageLevel = options.remove("edgeStorageLevel")
.map(StorageLevel.fromString(_)).getOrElse(StorageLevel.MEMORY_ONLY)
val vertexStorageLevel = options.remove("vertexStorageLevel")
@@ -90,7 +77,7 @@ object Analytics extends Logging {
val sc = new SparkContext(conf.setAppName("PageRank(" + fname + ")"))
val unpartitionedGraph = GraphLoader.edgeListFile(sc, fname,
- minEdgePartitions = numEPart,
+ numEdgePartitions = numEPart,
edgeStorageLevel = edgeStorageLevel,
vertexStorageLevel = vertexStorageLevel).cache()
val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_))
@@ -107,7 +94,7 @@ object Analytics extends Logging {
if (!outFname.isEmpty) {
logWarning("Saving pageranks of pages to " + outFname)
- pr.map{case (id, r) => id + "\t" + r}.saveAsTextFile(outFname)
+ pr.map { case (id, r) => id + "\t" + r }.saveAsTextFile(outFname)
}
sc.stop()
@@ -123,13 +110,13 @@ object Analytics extends Logging {
val sc = new SparkContext(conf.setAppName("ConnectedComponents(" + fname + ")"))
val unpartitionedGraph = GraphLoader.edgeListFile(sc, fname,
- minEdgePartitions = numEPart,
+ numEdgePartitions = numEPart,
edgeStorageLevel = edgeStorageLevel,
vertexStorageLevel = vertexStorageLevel).cache()
val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_))
val cc = ConnectedComponents.run(graph)
- println("Components: " + cc.vertices.map{ case (vid,data) => data}.distinct())
+ println("Components: " + cc.vertices.map { case (vid, data) => data }.distinct())
sc.stop()
case "triangles" =>
@@ -144,10 +131,10 @@ object Analytics extends Logging {
val sc = new SparkContext(conf.setAppName("TriangleCount(" + fname + ")"))
val graph = GraphLoader.edgeListFile(sc, fname,
canonicalOrientation = true,
- minEdgePartitions = numEPart,
+ numEdgePartitions = numEPart,
edgeStorageLevel = edgeStorageLevel,
vertexStorageLevel = vertexStorageLevel)
- // TriangleCount requires the graph to be partitioned
+ // TriangleCount requires the graph to be partitioned
.partitionBy(partitionStrategy.getOrElse(RandomVertexCut)).cache()
val triangles = TriangleCount.run(graph)
println("Triangles: " + triangles.vertices.map {
diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala
index 5f35a5836462e..3ec20d594b784 100644
--- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala
@@ -18,7 +18,7 @@
package org.apache.spark.examples.graphx
import org.apache.spark.SparkContext._
-import org.apache.spark.graphx.PartitionStrategy
+import org.apache.spark.graphx.{GraphXUtils, PartitionStrategy}
import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.graphx.util.GraphGenerators
import java.io.{PrintWriter, FileOutputStream}
@@ -67,7 +67,7 @@ object SynthBenchmark {
options.foreach {
case ("app", v) => app = v
- case ("niter", v) => niter = v.toInt
+ case ("niters", v) => niter = v.toInt
case ("nverts", v) => numVertices = v.toInt
case ("numEPart", v) => numEPart = Some(v.toInt)
case ("partStrategy", v) => partitionStrategy = Some(PartitionStrategy.fromString(v))
@@ -80,8 +80,7 @@ object SynthBenchmark {
val conf = new SparkConf()
.setAppName(s"GraphX Synth Benchmark (nverts = $numVertices, app = $app)")
- .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
+ GraphXUtils.registerKryoClasses(conf)
val sc = new SparkContext(conf)
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
new file mode 100644
index 0000000000000..ee7897d9062d9
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scala.beans.BeanInfo
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.ml.Pipeline
+import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
+import org.apache.spark.sql.SQLContext
+
+@BeanInfo
+case class LabeledDocument(id: Long, text: String, label: Double)
+
+@BeanInfo
+case class Document(id: Long, text: String)
+
+/**
+ * A simple text classification pipeline that recognizes "spark" from input text. This is to show
+ * how to create and configure an ML pipeline. Run with
+ * {{{
+ * bin/run-example ml.SimpleTextClassificationPipeline
+ * }}}
+ */
+object SimpleTextClassificationPipeline {
+
+ def main(args: Array[String]) {
+ val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+ import sqlContext._
+
+ // Prepare training documents, which are labeled.
+ val training = sparkContext.parallelize(Seq(
+ LabeledDocument(0L, "a b c d e spark", 1.0),
+ LabeledDocument(1L, "b d", 0.0),
+ LabeledDocument(2L, "spark f g h", 1.0),
+ LabeledDocument(3L, "hadoop mapreduce", 0.0)))
+
+ // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
+ val tokenizer = new Tokenizer()
+ .setInputCol("text")
+ .setOutputCol("words")
+ val hashingTF = new HashingTF()
+ .setNumFeatures(1000)
+ .setInputCol(tokenizer.getOutputCol)
+ .setOutputCol("features")
+ val lr = new LogisticRegression()
+ .setMaxIter(10)
+ .setRegParam(0.01)
+ val pipeline = new Pipeline()
+ .setStages(Array(tokenizer, hashingTF, lr))
+
+ // Fit the pipeline to training documents.
+ val model = pipeline.fit(training)
+
+ // Prepare test documents, which are unlabeled.
+ val test = sparkContext.parallelize(Seq(
+ Document(4L, "spark i j k"),
+ Document(5L, "l m n"),
+ Document(6L, "mapreduce spark"),
+ Document(7L, "apache hadoop")))
+
+ // Make predictions on test documents.
+ model.transform(test)
+ .select('id, 'text, 'score, 'prediction)
+ .collect()
+ .foreach(println)
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala
new file mode 100644
index 0000000000000..ae6057758d6fc
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import scala.reflect.runtime.universe._
+
+/**
+ * Abstract class for parameter case classes.
+ * This overrides the [[toString]] method to print all case class fields by name and value.
+ * @tparam T Concrete parameter class.
+ */
+abstract class AbstractParams[T: TypeTag] {
+
+ private def tag: TypeTag[T] = typeTag[T]
+
+ /**
+ * Finds all case class fields in concrete class instance, and outputs them in JSON-style format:
+ * {
+ * [field name]:\t[field value]\n
+ * [field name]:\t[field value]\n
+ * ...
+ * }
+ */
+ override def toString: String = {
+ val tpe = tag.tpe
+ val allAccessors = tpe.declarations.collect {
+ case m: MethodSymbol if m.isCaseAccessor => m
+ }
+ val mirror = runtimeMirror(getClass.getClassLoader)
+ val instanceMirror = mirror.reflect(this)
+ allAccessors.map { f =>
+ val paramName = f.name.toString
+ val fieldMirror = instanceMirror.reflectField(f)
+ val paramValue = fieldMirror.get
+ s" $paramName:\t$paramValue"
+ }.mkString("{\n", ",\n", "\n}")
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala
index a6f78d2441db1..a113653810b93 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala
@@ -55,7 +55,7 @@ object BinaryClassification {
stepSize: Double = 1.0,
algorithm: Algorithm = LR,
regType: RegType = L2,
- regParam: Double = 0.1)
+ regParam: Double = 0.01) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala
index d6b2fe430e5a4..e49129c4e7844 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala
@@ -35,6 +35,7 @@ import org.apache.spark.{SparkConf, SparkContext}
object Correlations {
case class Params(input: String = "data/mllib/sample_linear_regression_data.txt")
+ extends AbstractParams[Params]
def main(args: Array[String]) {
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala
new file mode 100644
index 0000000000000..cb1abbd18fd4d
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import scopt.OptionParser
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, RowMatrix}
+import org.apache.spark.{SparkConf, SparkContext}
+
+/**
+ * Compute the similar columns of a matrix, using cosine similarity.
+ *
+ * The input matrix must be stored in row-oriented dense format, one line per row with its entries
+ * separated by space. For example,
+ * {{{
+ * 0.5 1.0
+ * 2.0 3.0
+ * 4.0 5.0
+ * }}}
+ * represents a 3-by-2 matrix, whose first row is (0.5, 1.0).
+ *
+ * Example invocation:
+ *
+ * bin/run-example mllib.CosineSimilarity \
+ * --threshold 0.1 data/mllib/sample_svm_data.txt
+ */
+object CosineSimilarity {
+ case class Params(inputFile: String = null, threshold: Double = 0.1)
+ extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("CosineSimilarity") {
+ head("CosineSimilarity: an example app.")
+ opt[Double]("threshold")
+ .required()
+ .text(s"threshold similarity: to tradeoff computation vs quality estimate")
+ .action((x, c) => c.copy(threshold = x))
+ arg[String]("")
+ .required()
+ .text(s"input file, one row per line, space-separated")
+ .action((x, c) => c.copy(inputFile = x))
+ note(
+ """
+ |For example, the following command runs this app on a dataset:
+ |
+ | ./bin/spark-submit --class org.apache.spark.examples.mllib.CosineSimilarity \
+ | examplesjar.jar \
+ | --threshold 0.1 data/mllib/sample_svm_data.txt
+ """.stripMargin)
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ } getOrElse {
+ System.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+ val conf = new SparkConf().setAppName("CosineSimilarity")
+ val sc = new SparkContext(conf)
+
+ // Load and parse the data file.
+ val rows = sc.textFile(params.inputFile).map { line =>
+ val values = line.split(' ').map(_.toDouble)
+ Vectors.dense(values)
+ }.cache()
+ val mat = new RowMatrix(rows)
+
+ // Compute similar columns perfectly, with brute force.
+ val exact = mat.columnSimilarities()
+
+ // Compute similar columns with estimation using DIMSUM
+ val approx = mat.columnSimilarities(params.threshold)
+
+ val exactEntries = exact.entries.map { case MatrixEntry(i, j, u) => ((i, j), u) }
+ val approxEntries = approx.entries.map { case MatrixEntry(i, j, v) => ((i, j), v) }
+ val MAE = exactEntries.leftOuterJoin(approxEntries).values.map {
+ case (u, Some(v)) =>
+ math.abs(u - v)
+ case (u, None) =>
+ math.abs(u)
+ }.mean()
+
+ println(s"Average absolute error in estimate is: $MAE")
+
+ sc.stop()
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
new file mode 100644
index 0000000000000..f8d83f4ec7327
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import java.io.File
+
+import com.google.common.io.Files
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext, SchemaRDD}
+
+/**
+ * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with
+ * {{{
+ * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object DatasetExample {
+
+ case class Params(
+ input: String = "data/mllib/sample_libsvm_data.txt",
+ dataFormat: String = "libsvm") extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("DatasetExample") {
+ head("Dataset: an example app using SchemaRDD as a Dataset for ML.")
+ opt[String]("input")
+ .text(s"input path to dataset")
+ .action((x, c) => c.copy(input = x))
+ opt[String]("dataFormat")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(input = x))
+ checkConfig { params =>
+ success
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+
+ val conf = new SparkConf().setAppName(s"DatasetExample with $params")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+ import sqlContext._ // for implicit conversions
+
+ // Load input data
+ val origData: RDD[LabeledPoint] = params.dataFormat match {
+ case "dense" => MLUtils.loadLabeledPoints(sc, params.input)
+ case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input)
+ }
+ println(s"Loaded ${origData.count()} instances from file: ${params.input}")
+
+ // Convert input data to SchemaRDD explicitly.
+ val schemaRDD: SchemaRDD = origData
+ println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}")
+ println(s"Converted to SchemaRDD with ${schemaRDD.count()} records")
+
+ // Select columns, using implicit conversion to SchemaRDD.
+ val labelsSchemaRDD: SchemaRDD = origData.select('label)
+ val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v }
+ val numLabels = labels.count()
+ val meanLabel = labels.fold(0.0)(_ + _) / numLabels
+ println(s"Selected label column with average value $meanLabel")
+
+ val featuresSchemaRDD: SchemaRDD = origData.select('features)
+ val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v }
+ val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())(
+ (summary, feat) => summary.add(feat),
+ (sum1, sum2) => sum1.merge(sum2))
+ println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")
+
+ val tmpDir = Files.createTempDir()
+ tmpDir.deleteOnExit()
+ val outputDir = new File(tmpDir, "dataset").toString
+ println(s"Saving to $outputDir as Parquet file.")
+ schemaRDD.saveAsParquetFile(outputDir)
+
+ println(s"Loading Parquet file with UDT from $outputDir.")
+ val newDataset = sqlContext.parquetFile(outputDir)
+
+ println(s"Schema from Parquet: ${newDataset.schema.prettyJson}")
+ val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v }
+ val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())(
+ (summary, feat) => summary.add(feat),
+ (sum1, sum2) => sum1.merge(sum2))
+ println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}")
+
+ sc.stop()
+ }
+
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 4adc91d2fbe65..98f9d1689c8e7 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -22,11 +22,11 @@ import scopt.OptionParser
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity}
+import org.apache.spark.mllib.tree.{DecisionTree, RandomForest, impurity}
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
@@ -62,7 +62,10 @@ object DecisionTreeRunner {
minInfoGain: Double = 0.0,
numTrees: Int = 1,
featureSubsetStrategy: String = "auto",
- fracTest: Double = 0.2)
+ fracTest: Double = 0.2,
+ useNodeIdCache: Boolean = false,
+ checkpointDir: Option[String] = None,
+ checkpointInterval: Int = 10) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
@@ -102,6 +105,21 @@ object DecisionTreeRunner {
.text(s"fraction of data to hold out for testing. If given option testInput, " +
s"this option is ignored. default: ${defaultParams.fracTest}")
.action((x, c) => c.copy(fracTest = x))
+ opt[Boolean]("useNodeIdCache")
+ .text(s"whether to use node Id cache during training, " +
+ s"default: ${defaultParams.useNodeIdCache}")
+ .action((x, c) => c.copy(useNodeIdCache = x))
+ opt[String]("checkpointDir")
+ .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
+ s"default: ${defaultParams.checkpointDir match {
+ case Some(strVal) => strVal
+ case None => "None"
+ }}")
+ .action((x, c) => c.copy(checkpointDir = Some(x)))
+ opt[Int]("checkpointInterval")
+ .text(s"how often to checkpoint the node Id cache, " +
+ s"default: ${defaultParams.checkpointInterval}")
+ .action((x, c) => c.copy(checkpointInterval = x))
opt[String]("testInput")
.text(s"input path to test dataset. If given, option fracTest is ignored." +
s" default: ${defaultParams.testInput}")
@@ -136,18 +154,30 @@ object DecisionTreeRunner {
}
}
- def run(params: Params) {
-
- val conf = new SparkConf().setAppName("DecisionTreeRunner")
- val sc = new SparkContext(conf)
-
+ /**
+ * Load training and test data from files.
+ * @param input Path to input dataset.
+ * @param dataFormat "libsvm" or "dense"
+ * @param testInput Path to test dataset.
+ * @param algo Classification or Regression
+ * @param fracTest Fraction of input data to hold out for testing. Ignored if testInput given.
+ * @return (training dataset, test dataset, number of classes),
+ * where the number of classes is inferred from data (and set to 0 for Regression)
+ */
+ private[mllib] def loadDatasets(
+ sc: SparkContext,
+ input: String,
+ dataFormat: String,
+ testInput: String,
+ algo: Algo,
+ fracTest: Double): (RDD[LabeledPoint], RDD[LabeledPoint], Int) = {
// Load training data and cache it.
- val origExamples = params.dataFormat match {
- case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache()
- case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache()
+ val origExamples = dataFormat match {
+ case "dense" => MLUtils.loadLabeledPoints(sc, input).cache()
+ case "libsvm" => MLUtils.loadLibSVMFile(sc, input).cache()
}
// For classification, re-index classes if needed.
- val (examples, classIndexMap, numClasses) = params.algo match {
+ val (examples, classIndexMap, numClasses) = algo match {
case Classification => {
// classCounts: class --> # examples in class
val classCounts = origExamples.map(_.label).countByValue()
@@ -185,13 +215,14 @@ object DecisionTreeRunner {
}
// Create training, test sets.
- val splits = if (params.testInput != "") {
+ val splits = if (testInput != "") {
// Load testInput.
- val origTestExamples = params.dataFormat match {
- case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput)
- case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput)
+ val numFeatures = examples.take(1)(0).features.size
+ val origTestExamples = dataFormat match {
+ case "dense" => MLUtils.loadLabeledPoints(sc, testInput)
+ case "libsvm" => MLUtils.loadLibSVMFile(sc, testInput, numFeatures)
}
- params.algo match {
+ algo match {
case Classification => {
// classCounts: class --> # examples in class
val testExamples = {
@@ -208,17 +239,31 @@ object DecisionTreeRunner {
}
} else {
// Split input into training, test.
- examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest))
+ examples.randomSplit(Array(1.0 - fracTest, fracTest))
}
val training = splits(0).cache()
val test = splits(1).cache()
+
val numTraining = training.count()
val numTest = test.count()
-
println(s"numTraining = $numTraining, numTest = $numTest.")
examples.unpersist(blocking = false)
+ (training, test, numClasses)
+ }
+
+ def run(params: Params) {
+
+ val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params")
+ val sc = new SparkContext(conf)
+
+ println(s"DecisionTreeRunner with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training, test, numClasses) = loadDatasets(sc, params.input, params.dataFormat,
+ params.testInput, params.algo, params.fracTest)
+
val impurityCalculator = params.impurity match {
case Gini => impurity.Gini
case Entropy => impurity.Entropy
@@ -233,9 +278,15 @@ object DecisionTreeRunner {
maxBins = params.maxBins,
numClassesForClassification = numClasses,
minInstancesPerNode = params.minInstancesPerNode,
- minInfoGain = params.minInfoGain)
+ minInfoGain = params.minInfoGain,
+ useNodeIdCache = params.useNodeIdCache,
+ checkpointDir = params.checkpointDir,
+ checkpointInterval = params.checkpointInterval)
if (params.numTrees == 1) {
+ val startTime = System.nanoTime()
val model = DecisionTree.train(training, strategy)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
if (model.numNodes < 20) {
println(model.toDebugString) // Print full model.
} else {
@@ -259,8 +310,11 @@ object DecisionTreeRunner {
} else {
val randomSeed = Utils.random.nextInt()
if (params.algo == Classification) {
+ val startTime = System.nanoTime()
val model = RandomForest.trainClassifier(training, strategy, params.numTrees,
params.featureSubsetStrategy, randomSeed)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
if (model.totalNumNodes < 30) {
println(model.toDebugString) // Print full model.
} else {
@@ -275,8 +329,11 @@ object DecisionTreeRunner {
println(s"Test accuracy = $testAccuracy")
}
if (params.algo == Regression) {
+ val startTime = System.nanoTime()
val model = RandomForest.trainRegressor(training, strategy, params.numTrees,
params.featureSubsetStrategy, randomSeed)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
if (model.totalNumNodes < 30) {
println(model.toDebugString) // Print full model.
} else {
@@ -295,19 +352,11 @@ object DecisionTreeRunner {
/**
* Calculates the mean squared error for regression.
*/
- private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = {
- data.map { y =>
- val err = tree.predict(y.features) - y.label
- err * err
- }.mean()
- }
-
- /**
- * Calculates the mean squared error for regression.
- */
- private def meanSquaredError(tree: RandomForestModel, data: RDD[LabeledPoint]): Double = {
+ private[mllib] def meanSquaredError(
+ model: { def predict(features: Vector): Double },
+ data: RDD[LabeledPoint]): Double = {
data.map { y =>
- val err = tree.predict(y.features) - y.label
+ val err = model.predict(y.features) - y.label
err * err
}.mean()
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala
index 89dfa26c2299c..11e35598baf50 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala
@@ -44,7 +44,7 @@ object DenseKMeans {
input: String = null,
k: Int = -1,
numIterations: Int = 10,
- initializationMode: InitializationMode = Parallel)
+ initializationMode: InitializationMode = Parallel) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
new file mode 100644
index 0000000000000..1def8b45a230c
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.mllib.tree.GradientBoostedTrees
+import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo}
+import org.apache.spark.util.Utils
+
+/**
+ * An example runner for Gradient Boosting using decision trees as weak learners. Run with
+ * {{{
+ * ./bin/run-example mllib.GradientBoostedTreesRunner [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ *
+ * Note: This script treats all features as real-valued (not categorical).
+ * To include categorical features, modify categoricalFeaturesInfo.
+ */
+object GradientBoostedTreesRunner {
+
+ case class Params(
+ input: String = null,
+ testInput: String = "",
+ dataFormat: String = "libsvm",
+ algo: String = "Classification",
+ maxDepth: Int = 5,
+ numIterations: Int = 10,
+ fracTest: Double = 0.2) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("GradientBoostedTrees") {
+ head("GradientBoostedTrees: an example decision tree app.")
+ opt[String]("algo")
+ .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}")
+ .action((x, c) => c.copy(algo = x))
+ opt[Int]("maxDepth")
+ .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
+ .action((x, c) => c.copy(maxDepth = x))
+ opt[Int]("numIterations")
+ .text(s"number of iterations of boosting," + s" default: ${defaultParams.numIterations}")
+ .action((x, c) => c.copy(numIterations = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[String]("testInput")
+ .text(s"input path to test dataset. If given, option fracTest is ignored." +
+ s" default: ${defaultParams.testInput}")
+ .action((x, c) => c.copy(testInput = x))
+ opt[String]("")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ arg[String]("")
+ .text("input path to labeled examples")
+ .required()
+ .action((x, c) => c.copy(input = x))
+ checkConfig { params =>
+ if (params.fracTest < 0 || params.fracTest > 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
+ } else {
+ success
+ }
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+
+ val conf = new SparkConf().setAppName(s"GradientBoostedTreesRunner with $params")
+ val sc = new SparkContext(conf)
+
+ println(s"GradientBoostedTreesRunner with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input,
+ params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest)
+
+ val boostingStrategy = BoostingStrategy.defaultParams(params.algo)
+ boostingStrategy.treeStrategy.numClassesForClassification = numClasses
+ boostingStrategy.numIterations = params.numIterations
+ boostingStrategy.treeStrategy.maxDepth = params.maxDepth
+
+ val randomSeed = Utils.random.nextInt()
+ if (params.algo == "Classification") {
+ val startTime = System.nanoTime()
+ val model = GradientBoostedTrees.train(training, boostingStrategy)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+ if (model.totalNumNodes < 30) {
+ println(model.toDebugString) // Print full model.
+ } else {
+ println(model) // Print model summary.
+ }
+ val trainAccuracy =
+ new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label)))
+ .precision
+ println(s"Train accuracy = $trainAccuracy")
+ val testAccuracy =
+ new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision
+ println(s"Test accuracy = $testAccuracy")
+ } else if (params.algo == "Regression") {
+ val startTime = System.nanoTime()
+ val model = GradientBoostedTrees.train(training, boostingStrategy)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+ if (model.totalNumNodes < 30) {
+ println(model.toDebugString) // Print full model.
+ } else {
+ println(model) // Print model summary.
+ }
+ val trainMSE = DecisionTreeRunner.meanSquaredError(model, training)
+ println(s"Train mean squared error = $trainMSE")
+ val testMSE = DecisionTreeRunner.meanSquaredError(model, test)
+ println(s"Test mean squared error = $testMSE")
+ }
+
+ sc.stop()
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala
index 05b7d66f8dffd..6a456ba7ec07b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala
@@ -33,7 +33,7 @@ import org.apache.spark.mllib.optimization.{SimpleUpdater, SquaredL2Updater, L1U
* A synthetic dataset can be found at `data/mllib/sample_linear_regression_data.txt`.
* If you use it as a template to create your own app, please use `spark-submit` to submit your app.
*/
-object LinearRegression extends App {
+object LinearRegression {
object RegType extends Enumeration {
type RegType = Value
@@ -47,42 +47,44 @@ object LinearRegression extends App {
numIterations: Int = 100,
stepSize: Double = 1.0,
regType: RegType = L2,
- regParam: Double = 0.1)
-
- val defaultParams = Params()
-
- val parser = new OptionParser[Params]("LinearRegression") {
- head("LinearRegression: an example app for linear regression.")
- opt[Int]("numIterations")
- .text("number of iterations")
- .action((x, c) => c.copy(numIterations = x))
- opt[Double]("stepSize")
- .text(s"initial step size, default: ${defaultParams.stepSize}")
- .action((x, c) => c.copy(stepSize = x))
- opt[String]("regType")
- .text(s"regularization type (${RegType.values.mkString(",")}), " +
- s"default: ${defaultParams.regType}")
- .action((x, c) => c.copy(regType = RegType.withName(x)))
- opt[Double]("regParam")
- .text(s"regularization parameter, default: ${defaultParams.regParam}")
- arg[String]("")
- .required()
- .text("input paths to labeled examples in LIBSVM format")
- .action((x, c) => c.copy(input = x))
- note(
- """
- |For example, the following command runs this app on a synthetic dataset:
- |
- | bin/spark-submit --class org.apache.spark.examples.mllib.LinearRegression \
- | examples/target/scala-*/spark-examples-*.jar \
- | data/mllib/sample_linear_regression_data.txt
- """.stripMargin)
- }
+ regParam: Double = 0.01) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("LinearRegression") {
+ head("LinearRegression: an example app for linear regression.")
+ opt[Int]("numIterations")
+ .text("number of iterations")
+ .action((x, c) => c.copy(numIterations = x))
+ opt[Double]("stepSize")
+ .text(s"initial step size, default: ${defaultParams.stepSize}")
+ .action((x, c) => c.copy(stepSize = x))
+ opt[String]("regType")
+ .text(s"regularization type (${RegType.values.mkString(",")}), " +
+ s"default: ${defaultParams.regType}")
+ .action((x, c) => c.copy(regType = RegType.withName(x)))
+ opt[Double]("regParam")
+ .text(s"regularization parameter, default: ${defaultParams.regParam}")
+ arg[String]("")
+ .required()
+ .text("input paths to labeled examples in LIBSVM format")
+ .action((x, c) => c.copy(input = x))
+ note(
+ """
+ |For example, the following command runs this app on a synthetic dataset:
+ |
+ | bin/spark-submit --class org.apache.spark.examples.mllib.LinearRegression \
+ | examples/target/scala-*/spark-examples-*.jar \
+ | data/mllib/sample_linear_regression_data.txt
+ """.stripMargin)
+ }
- parser.parse(args, defaultParams).map { params =>
- run(params)
- } getOrElse {
- sys.exit(1)
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ } getOrElse {
+ sys.exit(1)
+ }
}
def run(params: Params) {
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
index 98aaedb9d7dc9..91a0a860d6c71 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
@@ -19,7 +19,6 @@ package org.apache.spark.examples.mllib
import scala.collection.mutable
-import com.esotericsoftware.kryo.Kryo
import org.apache.log4j.{Level, Logger}
import scopt.OptionParser
@@ -27,7 +26,6 @@ import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.recommendation.{ALS, MatrixFactorizationModel, Rating}
import org.apache.spark.rdd.RDD
-import org.apache.spark.serializer.{KryoSerializer, KryoRegistrator}
/**
* An example app for ALS on MovieLens data (http://grouplens.org/datasets/movielens/).
@@ -40,13 +38,6 @@ import org.apache.spark.serializer.{KryoSerializer, KryoRegistrator}
*/
object MovieLensALS {
- class ALSRegistrator extends KryoRegistrator {
- override def registerClasses(kryo: Kryo) {
- kryo.register(classOf[Rating])
- kryo.register(classOf[mutable.BitSet])
- }
- }
-
case class Params(
input: String = null,
kryo: Boolean = false,
@@ -55,7 +46,7 @@ object MovieLensALS {
rank: Int = 10,
numUserBlocks: Int = -1,
numProductBlocks: Int = -1,
- implicitPrefs: Boolean = false)
+ implicitPrefs: Boolean = false) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
@@ -108,17 +99,18 @@ object MovieLensALS {
def run(params: Params) {
val conf = new SparkConf().setAppName(s"MovieLensALS with $params")
if (params.kryo) {
- conf.set("spark.serializer", classOf[KryoSerializer].getName)
- .set("spark.kryo.registrator", classOf[ALSRegistrator].getName)
+ conf.registerKryoClasses(Array(classOf[mutable.BitSet], classOf[Rating]))
.set("spark.kryoserializer.buffer.mb", "8")
}
val sc = new SparkContext(conf)
Logger.getRootLogger.setLevel(Level.WARN)
+ val implicitPrefs = params.implicitPrefs
+
val ratings = sc.textFile(params.input).map { line =>
val fields = line.split("::")
- if (params.implicitPrefs) {
+ if (implicitPrefs) {
/*
* MovieLens ratings are on a scale of 1-5:
* 5: Must see
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala
index 4532512c01f84..6e4e2d07f284b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala
@@ -36,6 +36,7 @@ import org.apache.spark.{SparkConf, SparkContext}
object MultivariateSummarizer {
case class Params(input: String = "data/mllib/sample_linear_regression_data.txt")
+ extends AbstractParams[Params]
def main(args: Array[String]) {
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala
index f01b8266e3fe3..663c12734af68 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala
@@ -33,6 +33,7 @@ import org.apache.spark.SparkContext._
object SampledRDDs {
case class Params(input: String = "data/mllib/sample_binary_classification_data.txt")
+ extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala
index 952fa2a5109a4..f1ff4e6911f5e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala
@@ -37,7 +37,7 @@ object SparseNaiveBayes {
input: String = null,
minPartitions: Int = 0,
numFeatures: Int = -1,
- lambda: Double = 1.0)
+ lambda: Double = 1.0) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala
new file mode 100644
index 0000000000000..33e5760aed997
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.clustering.StreamingKMeans
+import org.apache.spark.SparkConf
+import org.apache.spark.streaming.{Seconds, StreamingContext}
+
+/**
+ * Estimate clusters on one stream of data and make predictions
+ * on another stream, where the data streams arrive as text files
+ * into two different directories.
+ *
+ * The rows of the training text files must be vector data in the form
+ * `[x1,x2,x3,...,xn]`
+ * Where n is the number of dimensions.
+ *
+ * The rows of the test text files must be labeled data in the form
+ * `(y,[x1,x2,x3,...,xn])`
+ * Where y is some identifier. n must be the same for train and test.
+ *
+ * Usage: StreamingKmeans
+ *
+ * To run on your local machine using the two directories `trainingDir` and `testDir`,
+ * with updates every 5 seconds, 2 dimensions per data point, and 3 clusters, call:
+ * $ bin/run-example \
+ * org.apache.spark.examples.mllib.StreamingKMeans trainingDir testDir 5 3 2
+ *
+ * As you add text files to `trainingDir` the clusters will continuously update.
+ * Anytime you add text files to `testDir`, you'll see predicted labels using the current model.
+ *
+ */
+object StreamingKMeans {
+
+ def main(args: Array[String]) {
+ if (args.length != 5) {
+ System.err.println(
+ "Usage: StreamingKMeans " +
+ "")
+ System.exit(1)
+ }
+
+ val conf = new SparkConf().setMaster("local").setAppName("StreamingLinearRegression")
+ val ssc = new StreamingContext(conf, Seconds(args(2).toLong))
+
+ val trainingData = ssc.textFileStream(args(0)).map(Vectors.parse)
+ val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse)
+
+ val model = new StreamingKMeans()
+ .setK(args(3).toInt)
+ .setDecayFactor(1.0)
+ .setRandomCenters(args(4).toInt, 0.0)
+
+ model.trainOn(trainingData)
+ model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()
+
+ ssc.start()
+ ssc.awaitTermination()
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
index e26f213e8afa8..227acc117502d 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
@@ -27,15 +27,16 @@ object HiveFromSpark {
def main(args: Array[String]) {
val sparkConf = new SparkConf().setAppName("HiveFromSpark")
val sc = new SparkContext(sparkConf)
+ val path = s"${System.getenv("SPARK_HOME")}/examples/src/main/resources/kv1.txt"
- // A local hive context creates an instance of the Hive Metastore in process, storing the
+ // A local hive context creates an instance of the Hive Metastore in process, storing
// the warehouse data in the current directory. This location can be overridden by
// specifying a second parameter to the constructor.
val hiveContext = new HiveContext(sc)
import hiveContext._
sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
- sql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src")
+ sql(s"LOAD DATA LOCAL INPATH '$path' INTO TABLE src")
// Queries are expressed in HiveQL
println("Result of 'SELECT *': ")
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala
index 6af3a0f33efc2..19427e629f76d 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala
@@ -31,15 +31,13 @@ import org.apache.spark.util.IntParam
/**
* Counts words in text encoded with UTF8 received from the network every second.
*
- * Usage: NetworkWordCount
+ * Usage: RecoverableNetworkWordCount
* and describe the TCP server that Spark Streaming would connect to receive
* data. directory to HDFS-compatible file system which checkpoint data
* file to which the word counts will be appended
*
- * In local mode, should be 'local[n]' with n > 1
* and must be absolute paths
*
- *
* To run this on your local machine, you need to first run a Netcat server
*
* `$ nc -lk 9999`
@@ -54,22 +52,11 @@ import org.apache.spark.util.IntParam
* checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from
* the checkpoint data.
*
- * To run this example in a local standalone cluster with automatic driver recovery,
- *
- * `$ bin/spark-class org.apache.spark.deploy.Client -s launch \
- * \
- * org.apache.spark.examples.streaming.RecoverableNetworkWordCount \
- * localhost 9999 ~/checkpoint ~/out`
- *
- * would typically be
- * /examples/target/scala-XX/spark-examples....jar
- *
* Refer to the online documentation for more details.
*/
-
object RecoverableNetworkWordCount {
- def createContext(ip: String, port: Int, outputPath: String) = {
+ def createContext(ip: String, port: Int, outputPath: String, checkpointDirectory: String) = {
// If you do not see this printed, that means the StreamingContext has been loaded
// from the new checkpoint
@@ -79,6 +66,7 @@ object RecoverableNetworkWordCount {
val sparkConf = new SparkConf().setAppName("RecoverableNetworkWordCount")
// Create the context with a 1 second batch size
val ssc = new StreamingContext(sparkConf, Seconds(1))
+ ssc.checkpoint(checkpointDirectory)
// Create a socket stream on target ip:port and count the
// words in input stream of \n delimited text (eg. generated by 'nc')
@@ -114,7 +102,7 @@ object RecoverableNetworkWordCount {
val Array(ip, IntParam(port), checkpointDirectory, outputPath) = args
val ssc = StreamingContext.getOrCreate(checkpointDirectory,
() => {
- createContext(ip, port, outputPath)
+ createContext(ip, port, outputPath, checkpointDirectory)
})
ssc.start()
ssc.awaitTermination()
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
index a4d159bf38377..ed186ea5650c4 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
@@ -18,12 +18,13 @@
package org.apache.spark.examples.streaming
import org.apache.spark.SparkConf
+import org.apache.spark.HashPartitioner
import org.apache.spark.streaming._
import org.apache.spark.streaming.StreamingContext._
/**
* Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every
- * second.
+ * second starting with initial value of word count.
* Usage: StatefulNetworkWordCount
* and describe the TCP server that Spark Streaming would connect to receive
* data.
@@ -51,12 +52,19 @@ object StatefulNetworkWordCount {
Some(currentCount + previousCount)
}
+ val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => {
+ iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
+ }
+
val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount")
// Create the context with a 1 second batch size
val ssc = new StreamingContext(sparkConf, Seconds(1))
ssc.checkpoint(".")
- // Create a NetworkInputDStream on target ip:port and count the
+ // Initial RDD input to updateStateByKey
+ val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))
+
+ // Create a ReceiverInputDStream on target ip:port and count the
// words in input stream of \n delimited test (eg. generated by 'nc')
val lines = ssc.socketTextStream(args(0), args(1).toInt)
val words = lines.flatMap(_.split(" "))
@@ -64,7 +72,8 @@ object StatefulNetworkWordCount {
// Update the cumulative count using updateStateByKey
// This will give a Dstream made of state (which is the cumulative count of the words)
- val stateDstream = wordDstream.updateStateByKey[Int](updateFunc)
+ val stateDstream = wordDstream.updateStateByKey[Int](newUpdateFunc,
+ new HashPartitioner (ssc.sparkContext.defaultParallelism), true, initialRDD)
stateDstream.print()
ssc.start()
ssc.awaitTermination()
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala
index d9b886eff77cc..55226c0a6df60 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala
@@ -50,7 +50,7 @@ object PageViewStream {
val ssc = new StreamingContext("local[2]", "PageViewStream", Seconds(1),
System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq)
- // Create a NetworkInputDStream on target host:port and convert each line to a PageView
+ // Create a ReceiverInputDStream on target host:port and convert each line to a PageView
val pageViews = ssc.socketTextStream(host, port)
.flatMap(_.split("\n"))
.map(PageView.fromString(_))
diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml
index ac291bd4fde20..72618b6515f83 100644
--- a/external/flume-sink/pom.xml
+++ b/external/flume-sink/pom.xml
@@ -21,7 +21,7 @@
org.apache.sparkspark-parent
- 1.2.0-SNAPSHOT
+ 1.3.0-SNAPSHOT../../pom.xml
diff --git a/external/flume-sink/src/test/resources/log4j.properties b/external/flume-sink/src/test/resources/log4j.properties
new file mode 100644
index 0000000000000..4411d6e20c52a
--- /dev/null
+++ b/external/flume-sink/src/test/resources/log4j.properties
@@ -0,0 +1,29 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Set everything to be logged to the file streaming/target/unit-tests.log
+log4j.rootCategory=INFO, file
+# log4j.appender.file=org.apache.log4j.FileAppender
+log4j.appender.file=org.apache.log4j.FileAppender
+log4j.appender.file.append=false
+log4j.appender.file.file=target/unit-tests.log
+log4j.appender.file.layout=org.apache.log4j.PatternLayout
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
+
+# Ignore messages below warning level from Jetty, because it's a bit verbose
+log4j.logger.org.eclipse.jetty=WARN
+
diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala
index a2b2cc6149d95..650b2fbe1c142 100644
--- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala
+++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala
@@ -159,6 +159,7 @@ class SparkSinkSuite extends FunSuite {
channelContext.put("transactionCapacity", 1000.toString)
channelContext.put("keep-alive", 0.toString)
channelContext.putAll(overrides)
+ channel.setName(scala.util.Random.nextString(10))
channel.configure(channelContext)
val sink = new SparkSink()
diff --git a/external/flume/pom.xml b/external/flume/pom.xml
index 7d31e32283d88..a682f0e8471d8 100644
--- a/external/flume/pom.xml
+++ b/external/flume/pom.xml
@@ -21,7 +21,7 @@
org.apache.sparkspark-parent
- 1.2.0-SNAPSHOT
+ 1.3.0-SNAPSHOT../../pom.xml
@@ -39,19 +39,13 @@
org.apache.sparkspark-streaming_${scala.binary.version}${project.version}
+ providedorg.apache.sparkspark-streaming-flume-sink_${scala.binary.version}${project.version}
-
- org.apache.spark
- spark-streaming_${scala.binary.version}
- ${project.version}
- test-jar
- test
- org.apache.flumeflume-ng-sdk
diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
index 4b2ea45fb81d0..2de2a7926bfd1 100644
--- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
+++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
@@ -66,7 +66,7 @@ class SparkFlumeEvent() extends Externalizable {
var event : AvroFlumeEvent = new AvroFlumeEvent()
/* De-serialize from bytes. */
- def readExternal(in: ObjectInput) {
+ def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
val bodyLength = in.readInt()
val bodyBuff = new Array[Byte](bodyLength)
in.readFully(bodyBuff)
@@ -93,7 +93,7 @@ class SparkFlumeEvent() extends Externalizable {
}
/* Serialize to bytes. */
- def writeExternal(out: ObjectOutput) {
+ def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
val body = event.getBody.array()
out.writeInt(body.length)
out.write(body)
diff --git a/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
new file mode 100644
index 0000000000000..6e1f01900071b
--- /dev/null
+++ b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming;
+
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import org.junit.After;
+import org.junit.Before;
+
+public abstract class LocalJavaStreamingContext {
+
+ protected transient JavaStreamingContext ssc;
+
+ @Before
+ public void setUp() {
+ System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock");
+ ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000));
+ ssc.checkpoint("checkpoint");
+ }
+
+ @After
+ public void tearDown() {
+ ssc.stop();
+ ssc = null;
+ }
+}
diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala b/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala
new file mode 100644
index 0000000000000..1a900007b696b
--- /dev/null
+++ b/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import java.io.{IOException, ObjectInputStream}
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.dstream.{DStream, ForEachDStream}
+import org.apache.spark.util.Utils
+
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
+/**
+ * This is a output stream just for the testsuites. All the output is collected into a
+ * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint.
+ *
+ * The buffer contains a sequence of RDD's, each containing a sequence of items
+ */
+class TestOutputStream[T: ClassTag](parent: DStream[T],
+ val output: ArrayBuffer[Seq[T]] = ArrayBuffer[Seq[T]]())
+ extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => {
+ val collected = rdd.collect()
+ output += collected
+ }) {
+
+ // This is to clear the output buffer every it is read from a checkpoint
+ @throws(classOf[IOException])
+ private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException {
+ ois.defaultReadObject()
+ output.clear()
+ }
+}
diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
index 32a19787a28e1..b57a1c71e35b9 100644
--- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
+++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
@@ -20,9 +20,6 @@ package org.apache.spark.streaming.flume
import java.net.InetSocketAddress
import java.util.concurrent.{Callable, ExecutorCompletionService, Executors}
-import java.util.Random
-
-import org.apache.spark.TestUtils
import scala.collection.JavaConversions._
import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer}
@@ -32,20 +29,35 @@ import org.apache.flume.channel.MemoryChannel
import org.apache.flume.conf.Configurables
import org.apache.flume.event.EventBuilder
+import org.scalatest.{BeforeAndAfter, FunSuite}
+
+import org.apache.spark.{SparkConf, Logging}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.streaming.util.ManualClock
-import org.apache.spark.streaming.{TestSuiteBase, TestOutputStream, StreamingContext}
+import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext}
import org.apache.spark.streaming.flume.sink._
import org.apache.spark.util.Utils
-class FlumePollingStreamSuite extends TestSuiteBase {
+class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging {
val batchCount = 5
val eventsPerBatch = 100
val totalEventsPerChannel = batchCount * eventsPerBatch
val channelCapacity = 5000
val maxAttempts = 5
+ val batchDuration = Seconds(1)
+
+ val conf = new SparkConf()
+ .setMaster("local[2]")
+ .setAppName(this.getClass.getSimpleName)
+
+ def beforeFunction() {
+ logInfo("Using manual clock")
+ conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
+ }
+
+ before(beforeFunction())
test("flume polling test") {
testMultipleTimes(testFlumePolling)
@@ -145,11 +157,16 @@ class FlumePollingStreamSuite extends TestSuiteBase {
outputStream.register()
ssc.start()
- writeAndVerify(Seq(channel, channel2), ssc, outputBuffer)
- assertChannelIsEmpty(channel)
- assertChannelIsEmpty(channel2)
- sink.stop()
- channel.stop()
+ try {
+ writeAndVerify(Seq(channel, channel2), ssc, outputBuffer)
+ assertChannelIsEmpty(channel)
+ assertChannelIsEmpty(channel2)
+ } finally {
+ sink.stop()
+ sink2.stop()
+ channel.stop()
+ channel2.stop()
+ }
}
def writeAndVerify(channels: Seq[MemoryChannel], ssc: StreamingContext,
@@ -224,4 +241,5 @@ class FlumePollingStreamSuite extends TestSuiteBase {
null
}
}
+
}
diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
index 33235d150b4a5..13943ed5442b9 100644
--- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
+++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala
@@ -17,103 +17,141 @@
package org.apache.spark.streaming.flume
-import scala.collection.JavaConversions._
-import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
-
-import java.net.InetSocketAddress
+import java.net.{InetSocketAddress, ServerSocket}
import java.nio.ByteBuffer
import java.nio.charset.Charset
+import scala.collection.JavaConversions._
+import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
import org.apache.avro.ipc.NettyTransceiver
import org.apache.avro.ipc.specific.SpecificRequestor
+import org.apache.flume.source.avro
import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol}
+import org.jboss.netty.channel.ChannelPipeline
+import org.jboss.netty.channel.socket.SocketChannel
+import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
+import org.jboss.netty.handler.codec.compression._
+import org.scalatest.{BeforeAndAfter, FunSuite, Matchers}
+import org.scalatest.concurrent.Eventually._
+import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.{TestOutputStream, StreamingContext, TestSuiteBase}
-import org.apache.spark.streaming.util.ManualClock
+import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream}
+import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerReceiverStarted}
import org.apache.spark.util.Utils
-import org.jboss.netty.channel.ChannelPipeline
-import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
-import org.jboss.netty.channel.socket.SocketChannel
-import org.jboss.netty.handler.codec.compression._
+class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with Logging {
+ val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite")
+
+ var ssc: StreamingContext = null
+ var transceiver: NettyTransceiver = null
-class FlumeStreamSuite extends TestSuiteBase {
+ after {
+ if (ssc != null) {
+ ssc.stop()
+ }
+ if (transceiver != null) {
+ transceiver.close()
+ }
+ }
test("flume input stream") {
- runFlumeStreamTest(false)
+ testFlumeStream(testCompression = false)
}
test("flume input compressed stream") {
- runFlumeStreamTest(true)
+ testFlumeStream(testCompression = true)
+ }
+
+ /** Run test on flume stream */
+ private def testFlumeStream(testCompression: Boolean): Unit = {
+ val input = (1 to 100).map { _.toString }
+ val testPort = findFreePort()
+ val outputBuffer = startContext(testPort, testCompression)
+ writeAndVerify(input, testPort, outputBuffer, testCompression)
+ }
+
+ /** Find a free port */
+ private def findFreePort(): Int = {
+ Utils.startServiceOnPort(23456, (trialPort: Int) => {
+ val socket = new ServerSocket(trialPort)
+ socket.close()
+ (null, trialPort)
+ })._2
}
-
- def runFlumeStreamTest(enableDecompression: Boolean) {
- // Set up the streaming context and input streams
- val ssc = new StreamingContext(conf, batchDuration)
- val (flumeStream, testPort) =
- Utils.startServiceOnPort(9997, (trialPort: Int) => {
- val dstream = FlumeUtils.createStream(
- ssc, "localhost", trialPort, StorageLevel.MEMORY_AND_DISK, enableDecompression)
- (dstream, trialPort)
- })
+ /** Setup and start the streaming context */
+ private def startContext(
+ testPort: Int, testCompression: Boolean): (ArrayBuffer[Seq[SparkFlumeEvent]]) = {
+ ssc = new StreamingContext(conf, Milliseconds(200))
+ val flumeStream = FlumeUtils.createStream(
+ ssc, "localhost", testPort, StorageLevel.MEMORY_AND_DISK, testCompression)
val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]]
with SynchronizedBuffer[Seq[SparkFlumeEvent]]
val outputStream = new TestOutputStream(flumeStream, outputBuffer)
outputStream.register()
ssc.start()
+ outputBuffer
+ }
- val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
- val input = Seq(1, 2, 3, 4, 5)
- Thread.sleep(1000)
- val transceiver = new NettyTransceiver(new InetSocketAddress("localhost", testPort))
- var client: AvroSourceProtocol = null
-
- if (enableDecompression) {
- client = SpecificRequestor.getClient(
- classOf[AvroSourceProtocol],
- new NettyTransceiver(new InetSocketAddress("localhost", testPort),
- new CompressionChannelFactory(6)))
- } else {
- client = SpecificRequestor.getClient(
- classOf[AvroSourceProtocol], transceiver)
- }
+ /** Send data to the flume receiver and verify whether the data was received */
+ private def writeAndVerify(
+ input: Seq[String],
+ testPort: Int,
+ outputBuffer: ArrayBuffer[Seq[SparkFlumeEvent]],
+ enableCompression: Boolean
+ ) {
+ val testAddress = new InetSocketAddress("localhost", testPort)
- for (i <- 0 until input.size) {
+ val inputEvents = input.map { item =>
val event = new AvroFlumeEvent
- event.setBody(ByteBuffer.wrap(input(i).toString.getBytes("utf-8")))
+ event.setBody(ByteBuffer.wrap(item.getBytes("UTF-8")))
event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header"))
- client.append(event)
- Thread.sleep(500)
- clock.addToTime(batchDuration.milliseconds)
+ event
}
- Thread.sleep(1000)
-
- val startTime = System.currentTimeMillis()
- while (outputBuffer.size < input.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) {
- logInfo("output.size = " + outputBuffer.size + ", input.size = " + input.size)
- Thread.sleep(100)
+ eventually(timeout(10 seconds), interval(100 milliseconds)) {
+ // if last attempted transceiver had succeeded, close it
+ if (transceiver != null) {
+ transceiver.close()
+ transceiver = null
+ }
+
+ // Create transceiver
+ transceiver = {
+ if (enableCompression) {
+ new NettyTransceiver(testAddress, new CompressionChannelFactory(6))
+ } else {
+ new NettyTransceiver(testAddress)
+ }
+ }
+
+ // Create Avro client with the transceiver
+ val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver)
+ client should not be null
+
+ // Send data
+ val status = client.appendBatch(inputEvents.toList)
+ status should be (avro.Status.OK)
}
- Thread.sleep(1000)
- val timeTaken = System.currentTimeMillis() - startTime
- assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms")
- logInfo("Stopping context")
- ssc.stop()
-
- val decoder = Charset.forName("UTF-8").newDecoder()
-
- assert(outputBuffer.size === input.length)
- for (i <- 0 until outputBuffer.size) {
- assert(outputBuffer(i).size === 1)
- val str = decoder.decode(outputBuffer(i).head.event.getBody)
- assert(str.toString === input(i).toString)
- assert(outputBuffer(i).head.event.getHeaders.get("test") === "header")
+
+ val decoder = Charset.forName("UTF-8").newDecoder()
+ eventually(timeout(10 seconds), interval(100 milliseconds)) {
+ val outputEvents = outputBuffer.flatten.map { _.event }
+ outputEvents.foreach {
+ event =>
+ event.getHeaders.get("test") should be("header")
+ }
+ val output = outputEvents.map(event => decoder.decode(event.getBody()).toString)
+ output should be (input)
}
}
- class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory {
+ /** Class to create socket channel with compression */
+ private class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory {
override def newChannel(pipeline: ChannelPipeline): SocketChannel = {
val encoder = new ZlibEncoder(compressionLevel)
pipeline.addFirst("deflater", encoder)
diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml
index 2067c473f0e3f..b3f44471cd326 100644
--- a/external/kafka/pom.xml
+++ b/external/kafka/pom.xml
@@ -21,7 +21,7 @@
org.apache.sparkspark-parent
- 1.2.0-SNAPSHOT
+ 1.3.0-SNAPSHOT../../pom.xml
@@ -39,13 +39,7 @@
org.apache.sparkspark-streaming_${scala.binary.version}${project.version}
-
-
- org.apache.spark
- spark-streaming_${scala.binary.version}
- ${project.version}
- test-jar
- test
+ providedorg.apache.kafka
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala
index e20e2c8f26991..4d26b640e8d74 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala
@@ -17,23 +17,21 @@
package org.apache.spark.streaming.kafka
+import java.util.Properties
+
import scala.collection.Map
import scala.reflect.{classTag, ClassTag}
-import java.util.Properties
-import java.util.concurrent.Executors
-
-import kafka.consumer._
+import kafka.consumer.{KafkaStream, Consumer, ConsumerConfig, ConsumerConnector}
import kafka.serializer.Decoder
import kafka.utils.VerifiableProperties
-import kafka.utils.ZKStringSerializer
-import org.I0Itec.zkclient._
import org.apache.spark.Logging
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.dstream._
import org.apache.spark.streaming.receiver.Receiver
+import org.apache.spark.util.Utils
/**
* Input stream that pulls messages from a Kafka Broker.
@@ -53,12 +51,16 @@ class KafkaInputDStream[
@transient ssc_ : StreamingContext,
kafkaParams: Map[String, String],
topics: Map[String, Int],
+ useReliableReceiver: Boolean,
storageLevel: StorageLevel
) extends ReceiverInputDStream[(K, V)](ssc_) with Logging {
def getReceiver(): Receiver[(K, V)] = {
- new KafkaReceiver[K, V, U, T](kafkaParams, topics, storageLevel)
- .asInstanceOf[Receiver[(K, V)]]
+ if (!useReliableReceiver) {
+ new KafkaReceiver[K, V, U, T](kafkaParams, topics, storageLevel)
+ } else {
+ new ReliableKafkaReceiver[K, V, U, T](kafkaParams, topics, storageLevel)
+ }
}
}
@@ -71,14 +73,15 @@ class KafkaReceiver[
kafkaParams: Map[String, String],
topics: Map[String, Int],
storageLevel: StorageLevel
- ) extends Receiver[Any](storageLevel) with Logging {
+ ) extends Receiver[(K, V)](storageLevel) with Logging {
// Connection to Kafka
- var consumerConnector : ConsumerConnector = null
+ var consumerConnector: ConsumerConnector = null
def onStop() {
if (consumerConnector != null) {
consumerConnector.shutdown()
+ consumerConnector = null
}
}
@@ -97,12 +100,6 @@ class KafkaReceiver[
consumerConnector = Consumer.create(consumerConfig)
logInfo("Connected to " + zkConnect)
- // When auto.offset.reset is defined, it is our responsibility to try and whack the
- // consumer group zk node.
- if (kafkaParams.contains("auto.offset.reset")) {
- tryZookeeperConsumerGroupCleanup(zkConnect, kafkaParams("group.id"))
- }
-
val keyDecoder = classTag[U].runtimeClass.getConstructor(classOf[VerifiableProperties])
.newInstance(consumerConfig.props)
.asInstanceOf[Decoder[K]]
@@ -110,11 +107,11 @@ class KafkaReceiver[
.newInstance(consumerConfig.props)
.asInstanceOf[Decoder[V]]
- // Create Threads for each Topic/Message Stream we are listening
+ // Create threads for each topic/message Stream we are listening
val topicMessageStreams = consumerConnector.createMessageStreams(
topics, keyDecoder, valueDecoder)
- val executorPool = Executors.newFixedThreadPool(topics.values.sum)
+ val executorPool = Utils.newDaemonFixedThreadPool(topics.values.sum, "KafkaMessageHandler")
try {
// Start the messages handler for each partition
topicMessageStreams.values.foreach { streams =>
@@ -125,13 +122,15 @@ class KafkaReceiver[
}
}
- // Handles Kafka Messages
- private class MessageHandler[K: ClassTag, V: ClassTag](stream: KafkaStream[K, V])
+ // Handles Kafka messages
+ private class MessageHandler(stream: KafkaStream[K, V])
extends Runnable {
def run() {
logInfo("Starting MessageHandler.")
try {
- for (msgAndMetadata <- stream) {
+ val streamIterator = stream.iterator()
+ while (streamIterator.hasNext()) {
+ val msgAndMetadata = streamIterator.next()
store((msgAndMetadata.key, msgAndMetadata.message))
}
} catch {
@@ -139,26 +138,4 @@ class KafkaReceiver[
}
}
}
-
- // It is our responsibility to delete the consumer group when specifying auto.offset.reset. This
- // is because Kafka 0.7.2 only honors this param when the group is not in zookeeper.
- //
- // The kafka high level consumer doesn't expose setting offsets currently, this is a trick copied
- // from Kafka's ConsoleConsumer. See code related to 'auto.offset.reset' when it is set to
- // 'smallest'/'largest':
- // scalastyle:off
- // https://github.com/apache/kafka/blob/0.7.2/core/src/main/scala/kafka/consumer/ConsoleConsumer.scala
- // scalastyle:on
- private def tryZookeeperConsumerGroupCleanup(zkUrl: String, groupId: String) {
- val dir = "/consumers/" + groupId
- logInfo("Cleaning up temporary Zookeeper data under " + dir + ".")
- val zk = new ZkClient(zkUrl, 30*1000, 30*1000, ZKStringSerializer)
- try {
- zk.deleteRecursive(dir)
- } catch {
- case e: Throwable => logWarning("Error cleaning up temporary Zookeeper data", e)
- } finally {
- zk.close()
- }
- }
}
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
index 48668f763e41e..b4ac929e0c070 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
@@ -17,19 +17,18 @@
package org.apache.spark.streaming.kafka
-import scala.reflect.ClassTag
-import scala.collection.JavaConversions._
-
import java.lang.{Integer => JInt}
import java.util.{Map => JMap}
+import scala.reflect.ClassTag
+import scala.collection.JavaConversions._
+
import kafka.serializer.{Decoder, StringDecoder}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContext
-import org.apache.spark.streaming.api.java.{JavaPairReceiverInputDStream, JavaStreamingContext, JavaPairDStream}
-import org.apache.spark.streaming.dstream.{ReceiverInputDStream, DStream}
-
+import org.apache.spark.streaming.api.java.{JavaPairReceiverInputDStream, JavaStreamingContext}
+import org.apache.spark.streaming.dstream.ReceiverInputDStream
object KafkaUtils {
/**
@@ -71,7 +70,8 @@ object KafkaUtils {
topics: Map[String, Int],
storageLevel: StorageLevel
): ReceiverInputDStream[(K, V)] = {
- new KafkaInputDStream[K, V, U, T](ssc, kafkaParams, topics, storageLevel)
+ val walEnabled = ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)
+ new KafkaInputDStream[K, V, U, T](ssc, kafkaParams, topics, walEnabled, storageLevel)
}
/**
@@ -100,7 +100,6 @@ object KafkaUtils {
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
* @param storageLevel RDD storage level.
- *
*/
def createStream(
jssc: JavaStreamingContext,
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala
new file mode 100644
index 0000000000000..be734b80272d1
--- /dev/null
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka
+
+import java.util.Properties
+import java.util.concurrent.{ThreadPoolExecutor, ConcurrentHashMap}
+
+import scala.collection.{Map, mutable}
+import scala.reflect.{ClassTag, classTag}
+
+import kafka.common.TopicAndPartition
+import kafka.consumer.{Consumer, ConsumerConfig, ConsumerConnector, KafkaStream}
+import kafka.message.MessageAndMetadata
+import kafka.serializer.Decoder
+import kafka.utils.{VerifiableProperties, ZKGroupTopicDirs, ZKStringSerializer, ZkUtils}
+import org.I0Itec.zkclient.ZkClient
+
+import org.apache.spark.{Logging, SparkEnv}
+import org.apache.spark.storage.{StorageLevel, StreamBlockId}
+import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver}
+import org.apache.spark.util.Utils
+
+/**
+ * ReliableKafkaReceiver offers the ability to reliably store data into BlockManager without loss.
+ * It is turned off by default and will be enabled when
+ * spark.streaming.receiver.writeAheadLog.enable is true. The difference compared to KafkaReceiver
+ * is that this receiver manages topic-partition/offset itself and updates the offset information
+ * after data is reliably stored as write-ahead log. Offsets will only be updated when data is
+ * reliably stored, so the potential data loss problem of KafkaReceiver can be eliminated.
+ *
+ * Note: ReliableKafkaReceiver will set auto.commit.enable to false to turn off automatic offset
+ * commit mechanism in Kafka consumer. So setting this configuration manually within kafkaParams
+ * will not take effect.
+ */
+private[streaming]
+class ReliableKafkaReceiver[
+ K: ClassTag,
+ V: ClassTag,
+ U <: Decoder[_]: ClassTag,
+ T <: Decoder[_]: ClassTag](
+ kafkaParams: Map[String, String],
+ topics: Map[String, Int],
+ storageLevel: StorageLevel)
+ extends Receiver[(K, V)](storageLevel) with Logging {
+
+ private val groupId = kafkaParams("group.id")
+ private val AUTO_OFFSET_COMMIT = "auto.commit.enable"
+ private def conf = SparkEnv.get.conf
+
+ /** High level consumer to connect to Kafka. */
+ private var consumerConnector: ConsumerConnector = null
+
+ /** zkClient to connect to Zookeeper to commit the offsets. */
+ private var zkClient: ZkClient = null
+
+ /**
+ * A HashMap to manage the offset for each topic/partition, this HashMap is called in
+ * synchronized block, so mutable HashMap will not meet concurrency issue.
+ */
+ private var topicPartitionOffsetMap: mutable.HashMap[TopicAndPartition, Long] = null
+
+ /** A concurrent HashMap to store the stream block id and related offset snapshot. */
+ private var blockOffsetMap: ConcurrentHashMap[StreamBlockId, Map[TopicAndPartition, Long]] = null
+
+ /**
+ * Manage the BlockGenerator in receiver itself for better managing block store and offset
+ * commit.
+ */
+ private var blockGenerator: BlockGenerator = null
+
+ /** Thread pool running the handlers for receiving message from multiple topics and partitions. */
+ private var messageHandlerThreadPool: ThreadPoolExecutor = null
+
+ override def onStart(): Unit = {
+ logInfo(s"Starting Kafka Consumer Stream with group: $groupId")
+
+ // Initialize the topic-partition / offset hash map.
+ topicPartitionOffsetMap = new mutable.HashMap[TopicAndPartition, Long]
+
+ // Initialize the stream block id / offset snapshot hash map.
+ blockOffsetMap = new ConcurrentHashMap[StreamBlockId, Map[TopicAndPartition, Long]]()
+
+ // Initialize the block generator for storing Kafka message.
+ blockGenerator = new BlockGenerator(new GeneratedBlockHandler, streamId, conf)
+
+ if (kafkaParams.contains(AUTO_OFFSET_COMMIT) && kafkaParams(AUTO_OFFSET_COMMIT) == "true") {
+ logWarning(s"$AUTO_OFFSET_COMMIT should be set to false in ReliableKafkaReceiver, " +
+ "otherwise we will manually set it to false to turn off auto offset commit in Kafka")
+ }
+
+ val props = new Properties()
+ kafkaParams.foreach(param => props.put(param._1, param._2))
+ // Manually set "auto.commit.enable" to "false" no matter user explicitly set it to true,
+ // we have to make sure this property is set to false to turn off auto commit mechanism in
+ // Kafka.
+ props.setProperty(AUTO_OFFSET_COMMIT, "false")
+
+ val consumerConfig = new ConsumerConfig(props)
+
+ assert(!consumerConfig.autoCommitEnable)
+
+ logInfo(s"Connecting to Zookeeper: ${consumerConfig.zkConnect}")
+ consumerConnector = Consumer.create(consumerConfig)
+ logInfo(s"Connected to Zookeeper: ${consumerConfig.zkConnect}")
+
+ zkClient = new ZkClient(consumerConfig.zkConnect, consumerConfig.zkSessionTimeoutMs,
+ consumerConfig.zkConnectionTimeoutMs, ZKStringSerializer)
+
+ messageHandlerThreadPool = Utils.newDaemonFixedThreadPool(
+ topics.values.sum, "KafkaMessageHandler")
+
+ blockGenerator.start()
+
+ val keyDecoder = classTag[U].runtimeClass.getConstructor(classOf[VerifiableProperties])
+ .newInstance(consumerConfig.props)
+ .asInstanceOf[Decoder[K]]
+
+ val valueDecoder = classTag[T].runtimeClass.getConstructor(classOf[VerifiableProperties])
+ .newInstance(consumerConfig.props)
+ .asInstanceOf[Decoder[V]]
+
+ val topicMessageStreams = consumerConnector.createMessageStreams(
+ topics, keyDecoder, valueDecoder)
+
+ topicMessageStreams.values.foreach { streams =>
+ streams.foreach { stream =>
+ messageHandlerThreadPool.submit(new MessageHandler(stream))
+ }
+ }
+ }
+
+ override def onStop(): Unit = {
+ if (messageHandlerThreadPool != null) {
+ messageHandlerThreadPool.shutdown()
+ messageHandlerThreadPool = null
+ }
+
+ if (consumerConnector != null) {
+ consumerConnector.shutdown()
+ consumerConnector = null
+ }
+
+ if (zkClient != null) {
+ zkClient.close()
+ zkClient = null
+ }
+
+ if (blockGenerator != null) {
+ blockGenerator.stop()
+ blockGenerator = null
+ }
+
+ if (topicPartitionOffsetMap != null) {
+ topicPartitionOffsetMap.clear()
+ topicPartitionOffsetMap = null
+ }
+
+ if (blockOffsetMap != null) {
+ blockOffsetMap.clear()
+ blockOffsetMap = null
+ }
+ }
+
+ /** Store a Kafka message and the associated metadata as a tuple. */
+ private def storeMessageAndMetadata(
+ msgAndMetadata: MessageAndMetadata[K, V]): Unit = {
+ val topicAndPartition = TopicAndPartition(msgAndMetadata.topic, msgAndMetadata.partition)
+ val data = (msgAndMetadata.key, msgAndMetadata.message)
+ val metadata = (topicAndPartition, msgAndMetadata.offset)
+ blockGenerator.addDataWithCallback(data, metadata)
+ }
+
+ /** Update stored offset */
+ private def updateOffset(topicAndPartition: TopicAndPartition, offset: Long): Unit = {
+ topicPartitionOffsetMap.put(topicAndPartition, offset)
+ }
+
+ /**
+ * Remember the current offsets for each topic and partition. This is called when a block is
+ * generated.
+ */
+ private def rememberBlockOffsets(blockId: StreamBlockId): Unit = {
+ // Get a snapshot of current offset map and store with related block id.
+ val offsetSnapshot = topicPartitionOffsetMap.toMap
+ blockOffsetMap.put(blockId, offsetSnapshot)
+ topicPartitionOffsetMap.clear()
+ }
+
+ /** Store the ready-to-be-stored block and commit the related offsets to zookeeper. */
+ private def storeBlockAndCommitOffset(
+ blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = {
+ store(arrayBuffer.asInstanceOf[mutable.ArrayBuffer[(K, V)]])
+ Option(blockOffsetMap.get(blockId)).foreach(commitOffset)
+ blockOffsetMap.remove(blockId)
+ }
+
+ /**
+ * Commit the offset of Kafka's topic/partition, the commit mechanism follow Kafka 0.8.x's
+ * metadata schema in Zookeeper.
+ */
+ private def commitOffset(offsetMap: Map[TopicAndPartition, Long]): Unit = {
+ if (zkClient == null) {
+ val thrown = new IllegalStateException("Zookeeper client is unexpectedly null")
+ stop("Zookeeper client is not initialized before commit offsets to ZK", thrown)
+ return
+ }
+
+ for ((topicAndPart, offset) <- offsetMap) {
+ try {
+ val topicDirs = new ZKGroupTopicDirs(groupId, topicAndPart.topic)
+ val zkPath = s"${topicDirs.consumerOffsetDir}/${topicAndPart.partition}"
+
+ ZkUtils.updatePersistentPath(zkClient, zkPath, offset.toString)
+ } catch {
+ case e: Exception =>
+ logWarning(s"Exception during commit offset $offset for topic" +
+ s"${topicAndPart.topic}, partition ${topicAndPart.partition}", e)
+ }
+
+ logInfo(s"Committed offset $offset for topic ${topicAndPart.topic}, " +
+ s"partition ${topicAndPart.partition}")
+ }
+ }
+
+ /** Class to handle received Kafka message. */
+ private final class MessageHandler(stream: KafkaStream[K, V]) extends Runnable {
+ override def run(): Unit = {
+ while (!isStopped) {
+ try {
+ val streamIterator = stream.iterator()
+ while (streamIterator.hasNext) {
+ storeMessageAndMetadata(streamIterator.next)
+ }
+ } catch {
+ case e: Exception =>
+ logError("Error handling message", e)
+ }
+ }
+ }
+ }
+
+ /** Class to handle blocks generated by the block generator. */
+ private final class GeneratedBlockHandler extends BlockGeneratorListener {
+
+ def onAddData(data: Any, metadata: Any): Unit = {
+ // Update the offset of the data that was added to the generator
+ if (metadata != null) {
+ val (topicAndPartition, offset) = metadata.asInstanceOf[(TopicAndPartition, Long)]
+ updateOffset(topicAndPartition, offset)
+ }
+ }
+
+ def onGenerateBlock(blockId: StreamBlockId): Unit = {
+ // Remember the offsets of topics/partitions when a block has been generated
+ rememberBlockOffsets(blockId)
+ }
+
+ def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = {
+ // Store block and commit the blocks offset
+ storeBlockAndCommitOffset(blockId, arrayBuffer)
+ }
+
+ def onError(message: String, throwable: Throwable): Unit = {
+ reportError(message, throwable)
+ }
+ }
+}
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
index efb0099c7c850..6e1abf3f385ee 100644
--- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
@@ -20,7 +20,10 @@
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
+import java.util.Random;
+import org.apache.spark.SparkConf;
+import org.apache.spark.streaming.Duration;
import scala.Predef;
import scala.Tuple2;
import scala.collection.JavaConverters;
@@ -32,8 +35,6 @@
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.storage.StorageLevel;
-import org.apache.spark.streaming.Duration;
-import org.apache.spark.streaming.LocalJavaStreamingContext;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
@@ -42,25 +43,27 @@
import org.junit.After;
import org.junit.Before;
-public class JavaKafkaStreamSuite extends LocalJavaStreamingContext implements Serializable {
- private transient KafkaStreamSuite testSuite = new KafkaStreamSuite();
+public class JavaKafkaStreamSuite implements Serializable {
+ private transient JavaStreamingContext ssc = null;
+ private transient Random random = new Random();
+ private transient KafkaStreamSuiteBase suiteBase = null;
@Before
- @Override
public void setUp() {
- testSuite.beforeFunction();
+ suiteBase = new KafkaStreamSuiteBase() { };
+ suiteBase.setupKafka();
System.clearProperty("spark.driver.port");
- //System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock");
- ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000));
+ SparkConf sparkConf = new SparkConf()
+ .setMaster("local[4]").setAppName(this.getClass().getSimpleName());
+ ssc = new JavaStreamingContext(sparkConf, new Duration(500));
}
@After
- @Override
public void tearDown() {
ssc.stop();
ssc = null;
System.clearProperty("spark.driver.port");
- testSuite.afterFunction();
+ suiteBase.tearDownKafka();
}
@Test
@@ -74,15 +77,15 @@ public void testKafkaStream() throws InterruptedException {
sent.put("b", 3);
sent.put("c", 10);
- testSuite.createTopic(topic);
+ suiteBase.createTopic(topic);
HashMap tmp = new HashMap(sent);
- testSuite.produceAndSendMessage(topic,
- JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap(
- Predef.>conforms()));
+ suiteBase.produceAndSendMessage(topic,
+ JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap(
+ Predef.>conforms()));
HashMap kafkaParams = new HashMap();
- kafkaParams.put("zookeeper.connect", testSuite.zkHost() + ":" + testSuite.zkPort());
- kafkaParams.put("group.id", "test-consumer-" + KafkaTestUtils.random().nextInt(10000));
+ kafkaParams.put("zookeeper.connect", suiteBase.zkAddress());
+ kafkaParams.put("group.id", "test-consumer-" + random.nextInt(10000));
kafkaParams.put("auto.offset.reset", "smallest");
JavaPairDStream stream = KafkaUtils.createStream(ssc,
@@ -124,11 +127,16 @@ public Void call(JavaPairRDD rdd) throws Exception {
);
ssc.start();
- ssc.awaitTermination(3000);
-
+ long startTime = System.currentTimeMillis();
+ boolean sizeMatches = false;
+ while (!sizeMatches && System.currentTimeMillis() - startTime < 20000) {
+ sizeMatches = sent.size() == result.size();
+ Thread.sleep(200);
+ }
Assert.assertEquals(sent.size(), result.size());
for (String k : sent.keySet()) {
Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue());
}
+ ssc.stop();
}
}
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
index 6943326eb750e..b19c053ebfc44 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
@@ -19,51 +19,57 @@ package org.apache.spark.streaming.kafka
import java.io.File
import java.net.InetSocketAddress
-import java.util.{Properties, Random}
+import java.util.Properties
import scala.collection.mutable
+import scala.concurrent.duration._
+import scala.language.postfixOps
+import scala.util.Random
import kafka.admin.CreateTopicCommand
import kafka.common.{KafkaException, TopicAndPartition}
-import kafka.producer.{KeyedMessage, ProducerConfig, Producer}
-import kafka.utils.ZKStringSerializer
+import kafka.producer.{KeyedMessage, Producer, ProducerConfig}
import kafka.serializer.{StringDecoder, StringEncoder}
import kafka.server.{KafkaConfig, KafkaServer}
-
+import kafka.utils.ZKStringSerializer
import org.I0Itec.zkclient.ZkClient
+import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer}
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.concurrent.Eventually
-import org.apache.zookeeper.server.ZooKeeperServer
-import org.apache.zookeeper.server.NIOServerCnxnFactory
-
-import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
+import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.{Milliseconds, StreamingContext}
import org.apache.spark.util.Utils
-class KafkaStreamSuite extends TestSuiteBase {
- import KafkaTestUtils._
-
- val zkHost = "localhost"
- var zkPort: Int = 0
- val zkConnectionTimeout = 6000
- val zkSessionTimeout = 6000
-
- protected var brokerPort = 9092
- protected var brokerConf: KafkaConfig = _
- protected var zookeeper: EmbeddedZookeeper = _
- protected var zkClient: ZkClient = _
- protected var server: KafkaServer = _
- protected var producer: Producer[String, String] = _
-
- override def useManualClock = false
-
- override def beforeFunction() {
+/**
+ * This is an abstract base class for Kafka testsuites. This has the functionality to set up
+ * and tear down local Kafka servers, and to push data using Kafka producers.
+ */
+abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Logging {
+
+ var zkAddress: String = _
+ var zkClient: ZkClient = _
+
+ private val zkHost = "localhost"
+ private val zkConnectionTimeout = 6000
+ private val zkSessionTimeout = 6000
+ private var zookeeper: EmbeddedZookeeper = _
+ private var zkPort: Int = 0
+ private var brokerPort = 9092
+ private var brokerConf: KafkaConfig = _
+ private var server: KafkaServer = _
+ private var producer: Producer[String, String] = _
+
+ def setupKafka() {
// Zookeeper server startup
zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort")
// Get the actual zookeeper binding port
zkPort = zookeeper.actualPort
+ zkAddress = s"$zkHost:$zkPort"
logInfo("==================== 0 ====================")
- zkClient = new ZkClient(s"$zkHost:$zkPort", zkSessionTimeout, zkConnectionTimeout,
+ zkClient = new ZkClient(zkAddress, zkSessionTimeout, zkConnectionTimeout,
ZKStringSerializer)
logInfo("==================== 1 ====================")
@@ -71,7 +77,7 @@ class KafkaStreamSuite extends TestSuiteBase {
var bindSuccess: Boolean = false
while(!bindSuccess) {
try {
- val brokerProps = getBrokerConfig(brokerPort, s"$zkHost:$zkPort")
+ val brokerProps = getBrokerConfig()
brokerConf = new KafkaConfig(brokerProps)
server = new KafkaServer(brokerConf)
logInfo("==================== 2 ====================")
@@ -89,53 +95,30 @@ class KafkaStreamSuite extends TestSuiteBase {
Thread.sleep(2000)
logInfo("==================== 4 ====================")
- super.beforeFunction()
}
- override def afterFunction() {
- producer.close()
- server.shutdown()
- brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) }
-
- zkClient.close()
- zookeeper.shutdown()
-
- super.afterFunction()
- }
-
- test("Kafka input stream") {
- val ssc = new StreamingContext(master, framework, batchDuration)
- val topic = "topic1"
- val sent = Map("a" -> 5, "b" -> 3, "c" -> 10)
- createTopic(topic)
- produceAndSendMessage(topic, sent)
+ def tearDownKafka() {
+ if (producer != null) {
+ producer.close()
+ producer = null
+ }
- val kafkaParams = Map("zookeeper.connect" -> s"$zkHost:$zkPort",
- "group.id" -> s"test-consumer-${random.nextInt(10000)}",
- "auto.offset.reset" -> "smallest")
+ if (server != null) {
+ server.shutdown()
+ server = null
+ }
- val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder](
- ssc,
- kafkaParams,
- Map(topic -> 1),
- StorageLevel.MEMORY_ONLY)
- val result = new mutable.HashMap[String, Long]()
- stream.map { case (k, v) => v }
- .countByValue()
- .foreachRDD { r =>
- val ret = r.collect()
- ret.toMap.foreach { kv =>
- val count = result.getOrElseUpdate(kv._1, 0) + kv._2
- result.put(kv._1, count)
- }
- }
- ssc.start()
- ssc.awaitTermination(3000)
+ brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) }
- assert(sent.size === result.size)
- sent.keys.foreach { k => assert(sent(k) === result(k).toInt) }
+ if (zkClient != null) {
+ zkClient.close()
+ zkClient = null
+ }
- ssc.stop()
+ if (zookeeper != null) {
+ zookeeper.shutdown()
+ zookeeper = null
+ }
}
private def createTestMessage(topic: String, sent: Map[String, Int])
@@ -150,58 +133,43 @@ class KafkaStreamSuite extends TestSuiteBase {
CreateTopicCommand.createTopic(zkClient, topic, 1, 1, "0")
logInfo("==================== 5 ====================")
// wait until metadata is propagated
- waitUntilMetadataIsPropagated(Seq(server), topic, 0, 1000)
+ waitUntilMetadataIsPropagated(topic, 0)
}
def produceAndSendMessage(topic: String, sent: Map[String, Int]) {
- val brokerAddr = brokerConf.hostName + ":" + brokerConf.port
- producer = new Producer[String, String](new ProducerConfig(getProducerConfig(brokerAddr)))
+ producer = new Producer[String, String](new ProducerConfig(getProducerConfig()))
producer.send(createTestMessage(topic, sent): _*)
+ producer.close()
logInfo("==================== 6 ====================")
}
-}
-
-object KafkaTestUtils {
- val random = new Random()
- def getBrokerConfig(port: Int, zkConnect: String): Properties = {
+ private def getBrokerConfig(): Properties = {
val props = new Properties()
props.put("broker.id", "0")
props.put("host.name", "localhost")
- props.put("port", port.toString)
+ props.put("port", brokerPort.toString)
props.put("log.dir", Utils.createTempDir().getAbsolutePath)
- props.put("zookeeper.connect", zkConnect)
+ props.put("zookeeper.connect", zkAddress)
props.put("log.flush.interval.messages", "1")
props.put("replica.socket.timeout.ms", "1500")
props
}
- def getProducerConfig(brokerList: String): Properties = {
+ private def getProducerConfig(): Properties = {
+ val brokerAddr = brokerConf.hostName + ":" + brokerConf.port
val props = new Properties()
- props.put("metadata.broker.list", brokerList)
+ props.put("metadata.broker.list", brokerAddr)
props.put("serializer.class", classOf[StringEncoder].getName)
props
}
- def waitUntilTrue(condition: () => Boolean, waitTime: Long): Boolean = {
- val startTime = System.currentTimeMillis()
- while (true) {
- if (condition())
- return true
- if (System.currentTimeMillis() > startTime + waitTime)
- return false
- Thread.sleep(waitTime.min(100L))
+ private def waitUntilMetadataIsPropagated(topic: String, partition: Int) {
+ eventually(timeout(1000 milliseconds), interval(100 milliseconds)) {
+ assert(
+ server.apis.leaderCache.keySet.contains(TopicAndPartition(topic, partition)),
+ s"Partition [$topic, $partition] metadata not propagated after timeout"
+ )
}
- // Should never go to here
- throw new RuntimeException("unexpected error")
- }
-
- def waitUntilMetadataIsPropagated(servers: Seq[KafkaServer], topic: String, partition: Int,
- timeout: Long) {
- assert(waitUntilTrue(() =>
- servers.foldLeft(true)(_ && _.apis.leaderCache.keySet.contains(
- TopicAndPartition(topic, partition))), timeout),
- s"Partition [$topic, $partition] metadata not propagated after timeout")
}
class EmbeddedZookeeper(val zkConnect: String) {
@@ -227,3 +195,53 @@ object KafkaTestUtils {
}
}
}
+
+
+class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter {
+ var ssc: StreamingContext = _
+
+ before {
+ setupKafka()
+ }
+
+ after {
+ if (ssc != null) {
+ ssc.stop()
+ ssc = null
+ }
+ tearDownKafka()
+ }
+
+ test("Kafka input stream") {
+ val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName)
+ ssc = new StreamingContext(sparkConf, Milliseconds(500))
+ val topic = "topic1"
+ val sent = Map("a" -> 5, "b" -> 3, "c" -> 10)
+ createTopic(topic)
+ produceAndSendMessage(topic, sent)
+
+ val kafkaParams = Map("zookeeper.connect" -> zkAddress,
+ "group.id" -> s"test-consumer-${Random.nextInt(10000)}",
+ "auto.offset.reset" -> "smallest")
+
+ val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder](
+ ssc, kafkaParams, Map(topic -> 1), StorageLevel.MEMORY_ONLY)
+ val result = new mutable.HashMap[String, Long]()
+ stream.map(_._2).countByValue().foreachRDD { r =>
+ val ret = r.collect()
+ ret.toMap.foreach { kv =>
+ val count = result.getOrElseUpdate(kv._1, 0) + kv._2
+ result.put(kv._1, count)
+ }
+ }
+ ssc.start()
+ eventually(timeout(10000 milliseconds), interval(100 milliseconds)) {
+ assert(sent.size === result.size)
+ sent.keys.foreach { k =>
+ assert(sent(k) === result(k).toInt)
+ }
+ }
+ ssc.stop()
+ }
+}
+
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
new file mode 100644
index 0000000000000..64ccc92c81fa9
--- /dev/null
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka
+
+
+import java.io.File
+
+import scala.collection.mutable
+import scala.concurrent.duration._
+import scala.language.postfixOps
+import scala.util.Random
+
+import com.google.common.io.Files
+import kafka.serializer.StringDecoder
+import kafka.utils.{ZKGroupTopicDirs, ZkUtils}
+import org.apache.commons.io.FileUtils
+import org.scalatest.BeforeAndAfter
+import org.scalatest.concurrent.Eventually
+
+import org.apache.spark.SparkConf
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.{Milliseconds, StreamingContext}
+
+class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter with Eventually {
+
+ val sparkConf = new SparkConf()
+ .setMaster("local[4]")
+ .setAppName(this.getClass.getSimpleName)
+ .set("spark.streaming.receiver.writeAheadLog.enable", "true")
+ val data = Map("a" -> 10, "b" -> 10, "c" -> 10)
+
+
+ var groupId: String = _
+ var kafkaParams: Map[String, String] = _
+ var ssc: StreamingContext = _
+ var tempDirectory: File = null
+
+ before {
+ setupKafka()
+ groupId = s"test-consumer-${Random.nextInt(10000)}"
+ kafkaParams = Map(
+ "zookeeper.connect" -> zkAddress,
+ "group.id" -> groupId,
+ "auto.offset.reset" -> "smallest"
+ )
+
+ ssc = new StreamingContext(sparkConf, Milliseconds(500))
+ tempDirectory = Files.createTempDir()
+ ssc.checkpoint(tempDirectory.getAbsolutePath)
+ }
+
+ after {
+ if (ssc != null) {
+ ssc.stop()
+ }
+ if (tempDirectory != null && tempDirectory.exists()) {
+ FileUtils.deleteDirectory(tempDirectory)
+ tempDirectory = null
+ }
+ tearDownKafka()
+ }
+
+
+ test("Reliable Kafka input stream with single topic") {
+ var topic = "test-topic"
+ createTopic(topic)
+ produceAndSendMessage(topic, data)
+
+ // Verify whether the offset of this group/topic/partition is 0 before starting.
+ assert(getCommitOffset(groupId, topic, 0) === None)
+
+ val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder](
+ ssc, kafkaParams, Map(topic -> 1), StorageLevel.MEMORY_ONLY)
+ val result = new mutable.HashMap[String, Long]()
+ stream.map { case (k, v) => v }.foreachRDD { r =>
+ val ret = r.collect()
+ ret.foreach { v =>
+ val count = result.getOrElseUpdate(v, 0) + 1
+ result.put(v, count)
+ }
+ }
+ ssc.start()
+ eventually(timeout(20000 milliseconds), interval(200 milliseconds)) {
+ // A basic process verification for ReliableKafkaReceiver.
+ // Verify whether received message number is equal to the sent message number.
+ assert(data.size === result.size)
+ // Verify whether each message is the same as the data to be verified.
+ data.keys.foreach { k => assert(data(k) === result(k).toInt) }
+ // Verify the offset number whether it is equal to the total message number.
+ assert(getCommitOffset(groupId, topic, 0) === Some(29L))
+ }
+ ssc.stop()
+ }
+
+ test("Reliable Kafka input stream with multiple topics") {
+ val topics = Map("topic1" -> 1, "topic2" -> 1, "topic3" -> 1)
+ topics.foreach { case (t, _) =>
+ createTopic(t)
+ produceAndSendMessage(t, data)
+ }
+
+ // Before started, verify all the group/topic/partition offsets are 0.
+ topics.foreach { case (t, _) => assert(getCommitOffset(groupId, t, 0) === None) }
+
+ // Consuming all the data sent to the broker which will potential commit the offsets internally.
+ val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder](
+ ssc, kafkaParams, topics, StorageLevel.MEMORY_ONLY)
+ stream.foreachRDD(_ => Unit)
+ ssc.start()
+ eventually(timeout(20000 milliseconds), interval(100 milliseconds)) {
+ // Verify the offset for each group/topic to see whether they are equal to the expected one.
+ topics.foreach { case (t, _) => assert(getCommitOffset(groupId, t, 0) === Some(29L)) }
+ }
+ ssc.stop()
+ }
+
+
+ /** Getting partition offset from Zookeeper. */
+ private def getCommitOffset(groupId: String, topic: String, partition: Int): Option[Long] = {
+ assert(zkClient != null, "Zookeeper client is not initialized")
+ val topicDirs = new ZKGroupTopicDirs(groupId, topic)
+ val zkPath = s"${topicDirs.consumerOffsetDir}/$partition"
+ ZkUtils.readDataMaybeNull(zkClient, zkPath)._1.map(_.toLong)
+ }
+}
diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml
index 371f1f1e9d39a..703806735b3ff 100644
--- a/external/mqtt/pom.xml
+++ b/external/mqtt/pom.xml
@@ -21,7 +21,7 @@
org.apache.sparkspark-parent
- 1.2.0-SNAPSHOT
+ 1.3.0-SNAPSHOT../../pom.xml
@@ -39,24 +39,13 @@
org.apache.sparkspark-streaming_${scala.binary.version}${project.version}
-
-
- org.apache.spark
- spark-streaming_${scala.binary.version}
- ${project.version}
- test-jar
- test
+ providedorg.eclipse.pahomqtt-client0.4.0
-
- ${akka.group}
- akka-zeromq_${scala.binary.version}
- ${akka.version}
- org.scalatestscalatest_${scala.binary.version}
diff --git a/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
new file mode 100644
index 0000000000000..6e1f01900071b
--- /dev/null
+++ b/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming;
+
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import org.junit.After;
+import org.junit.Before;
+
+public abstract class LocalJavaStreamingContext {
+
+ protected transient JavaStreamingContext ssc;
+
+ @Before
+ public void setUp() {
+ System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock");
+ ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000));
+ ssc.checkpoint("checkpoint");
+ }
+
+ @After
+ public void tearDown() {
+ ssc.stop();
+ ssc = null;
+ }
+}
diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
index 467fd263e2d64..84595acf45ccb 100644
--- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
+++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
@@ -17,11 +17,19 @@
package org.apache.spark.streaming.mqtt
-import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
+import org.scalatest.FunSuite
+
+import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream
-class MQTTStreamSuite extends TestSuiteBase {
+class MQTTStreamSuite extends FunSuite {
+
+ val batchDuration = Seconds(1)
+
+ private val master: String = "local[2]"
+
+ private val framework: String = this.getClass.getSimpleName
test("mqtt input stream") {
val ssc = new StreamingContext(master, framework, batchDuration)
diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml
index 1d7dd49d15c22..000ace1446e5e 100644
--- a/external/twitter/pom.xml
+++ b/external/twitter/pom.xml
@@ -21,7 +21,7 @@
org.apache.sparkspark-parent
- 1.2.0-SNAPSHOT
+ 1.3.0-SNAPSHOT../../pom.xml
@@ -39,13 +39,7 @@
org.apache.sparkspark-streaming_${scala.binary.version}${project.version}
-
-
- org.apache.spark
- spark-streaming_${scala.binary.version}
- ${project.version}
- test-jar
- test
+ providedorg.twitter4j
diff --git a/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
new file mode 100644
index 0000000000000..6e1f01900071b
--- /dev/null
+++ b/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming;
+
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import org.junit.After;
+import org.junit.Before;
+
+public abstract class LocalJavaStreamingContext {
+
+ protected transient JavaStreamingContext ssc;
+
+ @Before
+ public void setUp() {
+ System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock");
+ ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000));
+ ssc.checkpoint("checkpoint");
+ }
+
+ @After
+ public void tearDown() {
+ ssc.stop();
+ ssc = null;
+ }
+}
diff --git a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala
index 93741e0375164..9ee57d7581d85 100644
--- a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala
+++ b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala
@@ -17,13 +17,23 @@
package org.apache.spark.streaming.twitter
-import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
-import org.apache.spark.storage.StorageLevel
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import twitter4j.Status
import twitter4j.auth.{NullAuthorization, Authorization}
+
+import org.apache.spark.Logging
+import org.apache.spark.streaming.{Seconds, StreamingContext}
+import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream
-import twitter4j.Status
-class TwitterStreamSuite extends TestSuiteBase {
+class TwitterStreamSuite extends FunSuite with BeforeAndAfter with Logging {
+
+ val batchDuration = Seconds(1)
+
+ private val master: String = "local[2]"
+
+ private val framework: String = this.getClass.getSimpleName
test("twitter input stream") {
val ssc = new StreamingContext(master, framework, batchDuration)
diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml
index 7e48968feb3bc..29c452093502e 100644
--- a/external/zeromq/pom.xml
+++ b/external/zeromq/pom.xml
@@ -21,7 +21,7 @@
org.apache.sparkspark-parent
- 1.2.0-SNAPSHOT
+ 1.3.0-SNAPSHOT../../pom.xml
@@ -39,13 +39,7 @@
org.apache.sparkspark-streaming_${scala.binary.version}${project.version}
-
-
- org.apache.spark
- spark-streaming_${scala.binary.version}
- ${project.version}
- test-jar
- test
+ provided${akka.group}
diff --git a/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
new file mode 100644
index 0000000000000..6e1f01900071b
--- /dev/null
+++ b/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming;
+
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import org.junit.After;
+import org.junit.Before;
+
+public abstract class LocalJavaStreamingContext {
+
+ protected transient JavaStreamingContext ssc;
+
+ @Before
+ public void setUp() {
+ System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock");
+ ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000));
+ ssc.checkpoint("checkpoint");
+ }
+
+ @After
+ public void tearDown() {
+ ssc.stop();
+ ssc = null;
+ }
+}
diff --git a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala
index cc10ff6ae03cd..a7566e733d891 100644
--- a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala
+++ b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala
@@ -20,12 +20,19 @@ package org.apache.spark.streaming.zeromq
import akka.actor.SupervisorStrategy
import akka.util.ByteString
import akka.zeromq.Subscribe
+import org.scalatest.FunSuite
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
+import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.dstream.ReceiverInputDStream
-class ZeroMQStreamSuite extends TestSuiteBase {
+class ZeroMQStreamSuite extends FunSuite {
+
+ val batchDuration = Seconds(1)
+
+ private val master: String = "local[2]"
+
+ private val framework: String = this.getClass.getSimpleName
test("zeromq input stream") {
val ssc = new StreamingContext(master, framework, batchDuration)
diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml
index 7e478bed62da7..c8477a6566311 100644
--- a/extras/java8-tests/pom.xml
+++ b/extras/java8-tests/pom.xml
@@ -20,7 +20,7 @@
org.apache.sparkspark-parent
- 1.2.0-SNAPSHOT
+ 1.3.0-SNAPSHOT../../pom.xml
diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml
index 560244ad93369..c0d3a61119113 100644
--- a/extras/kinesis-asl/pom.xml
+++ b/extras/kinesis-asl/pom.xml
@@ -20,7 +20,7 @@
org.apache.sparkspark-parent
- 1.2.0-SNAPSHOT
+ 1.3.0-SNAPSHOT../../pom.xml
diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml
index 71a078d58a8d8..d1427f6a0c6e9 100644
--- a/extras/spark-ganglia-lgpl/pom.xml
+++ b/extras/spark-ganglia-lgpl/pom.xml
@@ -20,7 +20,7 @@
org.apache.sparkspark-parent
- 1.2.0-SNAPSHOT
+ 1.3.0-SNAPSHOT../../pom.xml
diff --git a/graphx/pom.xml b/graphx/pom.xml
index 3f49b1d63b6e1..9982b36f9b62f 100644
--- a/graphx/pom.xml
+++ b/graphx/pom.xml
@@ -21,7 +21,7 @@
org.apache.sparkspark-parent
- 1.2.0-SNAPSHOT
+ 1.3.0-SNAPSHOT../pom.xml
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala
new file mode 100644
index 0000000000000..f70715fca6eea
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.graphx
+
+/**
+ * Represents an edge along with its neighboring vertices and allows sending messages along the
+ * edge. Used in [[Graph#aggregateMessages]].
+ */
+abstract class EdgeContext[VD, ED, A] {
+ /** The vertex id of the edge's source vertex. */
+ def srcId: VertexId
+ /** The vertex id of the edge's destination vertex. */
+ def dstId: VertexId
+ /** The vertex attribute of the edge's source vertex. */
+ def srcAttr: VD
+ /** The vertex attribute of the edge's destination vertex. */
+ def dstAttr: VD
+ /** The attribute associated with the edge. */
+ def attr: ED
+
+ /** Sends a message to the source vertex. */
+ def sendToSrc(msg: A): Unit
+ /** Sends a message to the destination vertex. */
+ def sendToDst(msg: A): Unit
+
+ /** Converts the edge and vertex properties into an [[EdgeTriplet]] for convenience. */
+ def toEdgeTriplet: EdgeTriplet[VD, ED] = {
+ val et = new EdgeTriplet[VD, ED]
+ et.srcId = srcId
+ et.srcAttr = srcAttr
+ et.dstId = dstId
+ et.dstAttr = dstAttr
+ et.attr = attr
+ et
+ }
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
index 5bcb96b136ed7..cc70b396a8dd4 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
@@ -17,14 +17,19 @@
package org.apache.spark.graphx
-import scala.reflect.{classTag, ClassTag}
+import scala.language.existentials
+import scala.reflect.ClassTag
-import org.apache.spark.{OneToOneDependency, Partition, Partitioner, TaskContext}
+import org.apache.spark.Dependency
+import org.apache.spark.Partition
+import org.apache.spark.SparkContext
+import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.graphx.impl.EdgePartition
import org.apache.spark.graphx.impl.EdgePartitionBuilder
+import org.apache.spark.graphx.impl.EdgeRDDImpl
/**
* `EdgeRDD[ED, VD]` extends `RDD[Edge[ED]]` by storing the edges in columnar format on each
@@ -32,33 +37,16 @@ import org.apache.spark.graphx.impl.EdgePartitionBuilder
* edge to provide the triplet view. Shipping of the vertex attributes is managed by
* `impl.ReplicatedVertexView`.
*/
-class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag](
- val partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])],
- val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY)
- extends RDD[Edge[ED]](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) {
+abstract class EdgeRDD[ED](
+ @transient sc: SparkContext,
+ @transient deps: Seq[Dependency[_]]) extends RDD[Edge[ED]](sc, deps) {
- override def setName(_name: String): this.type = {
- if (partitionsRDD.name != null) {
- partitionsRDD.setName(partitionsRDD.name + ", " + _name)
- } else {
- partitionsRDD.setName(_name)
- }
- this
- }
- setName("EdgeRDD")
+ private[graphx] def partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])] forSome { type VD }
override protected def getPartitions: Array[Partition] = partitionsRDD.partitions
- /**
- * If `partitionsRDD` already has a partitioner, use it. Otherwise assume that the
- * [[PartitionID]]s in `partitionsRDD` correspond to the actual partitions and create a new
- * partitioner that allows co-partitioning with `partitionsRDD`.
- */
- override val partitioner =
- partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD)))
-
override def compute(part: Partition, context: TaskContext): Iterator[Edge[ED]] = {
- val p = firstParent[(PartitionID, EdgePartition[ED, VD])].iterator(part, context)
+ val p = firstParent[(PartitionID, EdgePartition[ED, _])].iterator(part, context)
if (p.hasNext) {
p.next._2.iterator.map(_.copy())
} else {
@@ -66,40 +54,6 @@ class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag](
}
}
- override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect()
-
- /**
- * Persists the edge partitions at the specified storage level, ignoring any existing target
- * storage level.
- */
- override def persist(newLevel: StorageLevel): this.type = {
- partitionsRDD.persist(newLevel)
- this
- }
-
- override def unpersist(blocking: Boolean = true): this.type = {
- partitionsRDD.unpersist(blocking)
- this
- }
-
- /** Persists the vertex partitions using `targetStorageLevel`, which defaults to MEMORY_ONLY. */
- override def cache(): this.type = {
- partitionsRDD.persist(targetStorageLevel)
- this
- }
-
- private[graphx] def mapEdgePartitions[ED2: ClassTag, VD2: ClassTag](
- f: (PartitionID, EdgePartition[ED, VD]) => EdgePartition[ED2, VD2]): EdgeRDD[ED2, VD2] = {
- this.withPartitionsRDD[ED2, VD2](partitionsRDD.mapPartitions({ iter =>
- if (iter.hasNext) {
- val (pid, ep) = iter.next()
- Iterator(Tuple2(pid, f(pid, ep)))
- } else {
- Iterator.empty
- }
- }, preservesPartitioning = true))
- }
-
/**
* Map the values in an edge partitioning preserving the structure but changing the values.
*
@@ -107,22 +61,14 @@ class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag](
* @param f the function from an edge to a new edge value
* @return a new EdgeRDD containing the new edge values
*/
- def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDD[ED2, VD] =
- mapEdgePartitions((pid, part) => part.map(f))
+ def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDD[ED2]
/**
* Reverse all the edges in this RDD.
*
* @return a new EdgeRDD containing all the edges reversed
*/
- def reverse: EdgeRDD[ED, VD] = mapEdgePartitions((pid, part) => part.reverse)
-
- /** Removes all edges but those matching `epred` and where both vertices match `vpred`. */
- def filter(
- epred: EdgeTriplet[VD, ED] => Boolean,
- vpred: (VertexId, VD) => Boolean): EdgeRDD[ED, VD] = {
- mapEdgePartitions((pid, part) => part.filter(epred, vpred))
- }
+ def reverse: EdgeRDD[ED]
/**
* Inner joins this EdgeRDD with another EdgeRDD, assuming both are partitioned using the same
@@ -134,23 +80,8 @@ class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag](
* with values supplied by `f`
*/
def innerJoin[ED2: ClassTag, ED3: ClassTag]
- (other: EdgeRDD[ED2, _])
- (f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3, VD] = {
- val ed2Tag = classTag[ED2]
- val ed3Tag = classTag[ED3]
- this.withPartitionsRDD[ED3, VD](partitionsRDD.zipPartitions(other.partitionsRDD, true) {
- (thisIter, otherIter) =>
- val (pid, thisEPart) = thisIter.next()
- val (_, otherEPart) = otherIter.next()
- Iterator(Tuple2(pid, thisEPart.innerJoin(otherEPart)(f)(ed2Tag, ed3Tag)))
- })
- }
-
- /** Replaces the vertex partitions while preserving all other properties of the VertexRDD. */
- private[graphx] def withPartitionsRDD[ED2: ClassTag, VD2: ClassTag](
- partitionsRDD: RDD[(PartitionID, EdgePartition[ED2, VD2])]): EdgeRDD[ED2, VD2] = {
- new EdgeRDD(partitionsRDD, this.targetStorageLevel)
- }
+ (other: EdgeRDD[ED2])
+ (f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3]
/**
* Changes the target storage level while preserving all other properties of the
@@ -159,11 +90,7 @@ class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag](
* This does not actually trigger a cache; to do this, call
* [[org.apache.spark.graphx.EdgeRDD#cache]] on the returned EdgeRDD.
*/
- private[graphx] def withTargetStorageLevel(
- targetStorageLevel: StorageLevel): EdgeRDD[ED, VD] = {
- new EdgeRDD(this.partitionsRDD, targetStorageLevel)
- }
-
+ private[graphx] def withTargetStorageLevel(targetStorageLevel: StorageLevel): EdgeRDD[ED]
}
object EdgeRDD {
@@ -173,7 +100,7 @@ object EdgeRDD {
* @tparam ED the edge attribute type
* @tparam VD the type of the vertex attributes that may be joined with the returned EdgeRDD
*/
- def fromEdges[ED: ClassTag, VD: ClassTag](edges: RDD[Edge[ED]]): EdgeRDD[ED, VD] = {
+ def fromEdges[ED: ClassTag, VD: ClassTag](edges: RDD[Edge[ED]]): EdgeRDDImpl[ED, VD] = {
val edgePartitions = edges.mapPartitionsWithIndex { (pid, iter) =>
val builder = new EdgePartitionBuilder[ED, VD]
iter.foreach { e =>
@@ -190,8 +117,8 @@ object EdgeRDD {
* @tparam ED the edge attribute type
* @tparam VD the type of the vertex attributes that may be joined with the returned EdgeRDD
*/
- def fromEdgePartitions[ED: ClassTag, VD: ClassTag](
- edgePartitions: RDD[(Int, EdgePartition[ED, VD])]): EdgeRDD[ED, VD] = {
- new EdgeRDD(edgePartitions)
+ private[graphx] def fromEdgePartitions[ED: ClassTag, VD: ClassTag](
+ edgePartitions: RDD[(Int, EdgePartition[ED, VD])]): EdgeRDDImpl[ED, VD] = {
+ new EdgeRDDImpl(edgePartitions)
}
}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
index fa4b891754c40..637791543514c 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
@@ -59,7 +59,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
* along with their vertex data.
*
*/
- @transient val edges: EdgeRDD[ED, VD]
+ @transient val edges: EdgeRDD[ED]
/**
* An RDD containing the edge triplets, which are edges along with the vertex data associated with
@@ -208,7 +208,37 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
*
*/
def mapTriplets[ED2: ClassTag](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
- mapTriplets((pid, iter) => iter.map(map))
+ mapTriplets((pid, iter) => iter.map(map), TripletFields.All)
+ }
+
+ /**
+ * Transforms each edge attribute using the map function, passing it the adjacent vertex
+ * attributes as well. If adjacent vertex values are not required,
+ * consider using `mapEdges` instead.
+ *
+ * @note This does not change the structure of the
+ * graph or modify the values of this graph. As a consequence
+ * the underlying index structures can be reused.
+ *
+ * @param map the function from an edge object to a new edge value.
+ * @param tripletFields which fields should be included in the edge triplet passed to the map
+ * function. If not all fields are needed, specifying this can improve performance.
+ *
+ * @tparam ED2 the new edge data type
+ *
+ * @example This function might be used to initialize edge
+ * attributes based on the attributes associated with each vertex.
+ * {{{
+ * val rawGraph: Graph[Int, Int] = someLoadFunction()
+ * val graph = rawGraph.mapTriplets[Int]( edge =>
+ * edge.src.data - edge.dst.data)
+ * }}}
+ *
+ */
+ def mapTriplets[ED2: ClassTag](
+ map: EdgeTriplet[VD, ED] => ED2,
+ tripletFields: TripletFields): Graph[VD, ED2] = {
+ mapTriplets((pid, iter) => iter.map(map), tripletFields)
}
/**
@@ -223,12 +253,15 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
* the underlying index structures can be reused.
*
* @param map the iterator transform
+ * @param tripletFields which fields should be included in the edge triplet passed to the map
+ * function. If not all fields are needed, specifying this can improve performance.
*
* @tparam ED2 the new edge data type
*
*/
- def mapTriplets[ED2: ClassTag](map: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2])
- : Graph[VD, ED2]
+ def mapTriplets[ED2: ClassTag](
+ map: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2],
+ tripletFields: TripletFields): Graph[VD, ED2]
/**
* Reverses all edges in the graph. If this graph contains an edge from a to b then the returned
@@ -287,6 +320,8 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
* "sent" to either vertex in the edge. The `reduceFunc` is then used to combine the output of
* the map phase destined to each vertex.
*
+ * This function is deprecated in 1.2.0 because of SPARK-3936. Use aggregateMessages instead.
+ *
* @tparam A the type of "message" to be sent to each vertex
*
* @param mapFunc the user defined map function which returns 0 or
@@ -296,13 +331,15 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
* be commutative and associative and is used to combine the output
* of the map phase
*
- * @param activeSetOpt optionally, a set of "active" vertices and a direction of edges to
- * consider when running `mapFunc`. If the direction is `In`, `mapFunc` will only be run on
- * edges with destination in the active set. If the direction is `Out`,
- * `mapFunc` will only be run on edges originating from vertices in the active set. If the
- * direction is `Either`, `mapFunc` will be run on edges with *either* vertex in the active set
- * . If the direction is `Both`, `mapFunc` will be run on edges with *both* vertices in the
- * active set. The active set must have the same index as the graph's vertices.
+ * @param activeSetOpt an efficient way to run the aggregation on a subset of the edges if
+ * desired. This is done by specifying a set of "active" vertices and an edge direction. The
+ * `sendMsg` function will then run only on edges connected to active vertices by edges in the
+ * specified direction. If the direction is `In`, `sendMsg` will only be run on edges with
+ * destination in the active set. If the direction is `Out`, `sendMsg` will only be run on edges
+ * originating from vertices in the active set. If the direction is `Either`, `sendMsg` will be
+ * run on edges with *either* vertex in the active set. If the direction is `Both`, `sendMsg`
+ * will be run on edges with *both* vertices in the active set. The active set must have the
+ * same index as the graph's vertices.
*
* @example We can use this function to compute the in-degree of each
* vertex
@@ -319,6 +356,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
* predicate or implement PageRank.
*
*/
+ @deprecated("use aggregateMessages", "1.2.0")
def mapReduceTriplets[A: ClassTag](
mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
reduceFunc: (A, A) => A,
@@ -326,8 +364,80 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
: VertexRDD[A]
/**
- * Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`. The
- * input table should contain at most one entry for each vertex. If no entry in `other` is
+ * Aggregates values from the neighboring edges and vertices of each vertex. The user-supplied
+ * `sendMsg` function is invoked on each edge of the graph, generating 0 or more messages to be
+ * sent to either vertex in the edge. The `mergeMsg` function is then used to combine all messages
+ * destined to the same vertex.
+ *
+ * @tparam A the type of message to be sent to each vertex
+ *
+ * @param sendMsg runs on each edge, sending messages to neighboring vertices using the
+ * [[EdgeContext]].
+ * @param mergeMsg used to combine messages from `sendMsg` destined to the same vertex. This
+ * combiner should be commutative and associative.
+ * @param tripletFields which fields should be included in the [[EdgeContext]] passed to the
+ * `sendMsg` function. If not all fields are needed, specifying this can improve performance.
+ *
+ * @example We can use this function to compute the in-degree of each
+ * vertex
+ * {{{
+ * val rawGraph: Graph[_, _] = Graph.textFile("twittergraph")
+ * val inDeg: RDD[(VertexId, Int)] =
+ * aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _)
+ * }}}
+ *
+ * @note By expressing computation at the edge level we achieve
+ * maximum parallelism. This is one of the core functions in the
+ * Graph API in that enables neighborhood level computation. For
+ * example this function can be used to count neighbors satisfying a
+ * predicate or implement PageRank.
+ *
+ */
+ def aggregateMessages[A: ClassTag](
+ sendMsg: EdgeContext[VD, ED, A] => Unit,
+ mergeMsg: (A, A) => A,
+ tripletFields: TripletFields = TripletFields.All)
+ : VertexRDD[A] = {
+ aggregateMessagesWithActiveSet(sendMsg, mergeMsg, tripletFields, None)
+ }
+
+ /**
+ * Aggregates values from the neighboring edges and vertices of each vertex. The user-supplied
+ * `sendMsg` function is invoked on each edge of the graph, generating 0 or more messages to be
+ * sent to either vertex in the edge. The `mergeMsg` function is then used to combine all messages
+ * destined to the same vertex.
+ *
+ * This variant can take an active set to restrict the computation and is intended for internal
+ * use only.
+ *
+ * @tparam A the type of message to be sent to each vertex
+ *
+ * @param sendMsg runs on each edge, sending messages to neighboring vertices using the
+ * [[EdgeContext]].
+ * @param mergeMsg used to combine messages from `sendMsg` destined to the same vertex. This
+ * combiner should be commutative and associative.
+ * @param tripletFields which fields should be included in the [[EdgeContext]] passed to the
+ * `sendMsg` function. If not all fields are needed, specifying this can improve performance.
+ * @param activeSetOpt an efficient way to run the aggregation on a subset of the edges if
+ * desired. This is done by specifying a set of "active" vertices and an edge direction. The
+ * `sendMsg` function will then run on only edges connected to active vertices by edges in the
+ * specified direction. If the direction is `In`, `sendMsg` will only be run on edges with
+ * destination in the active set. If the direction is `Out`, `sendMsg` will only be run on edges
+ * originating from vertices in the active set. If the direction is `Either`, `sendMsg` will be
+ * run on edges with *either* vertex in the active set. If the direction is `Both`, `sendMsg`
+ * will be run on edges with *both* vertices in the active set. The active set must have the
+ * same index as the graph's vertices.
+ */
+ private[graphx] def aggregateMessagesWithActiveSet[A: ClassTag](
+ sendMsg: EdgeContext[VD, ED, A] => Unit,
+ mergeMsg: (A, A) => A,
+ tripletFields: TripletFields,
+ activeSetOpt: Option[(VertexRDD[_], EdgeDirection)])
+ : VertexRDD[A]
+
+ /**
+ * Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`.
+ * The input table should contain at most one entry for each vertex. If no entry in `other` is
* provided for a particular vertex in the graph, the map function receives `None`.
*
* @tparam U the type of entry in the table of updates
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
index 1948c978c30bf..563c948957ecf 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
@@ -27,10 +27,10 @@ import org.apache.spark.graphx.impl._
import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
import org.apache.spark.util.collection.OpenHashSet
-
/**
* Registers GraphX classes with Kryo for improved performance.
*/
+@deprecated("Register GraphX classes with Kryo using GraphXUtils.registerKryoClasses", "1.2.0")
class GraphKryoRegistrator extends KryoRegistrator {
def registerClasses(kryo: Kryo) {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala
index f4c79365b16da..4933aecba1286 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala
@@ -48,7 +48,8 @@ object GraphLoader extends Logging {
* @param path the path to the file (e.g., /home/data/file or hdfs://file)
* @param canonicalOrientation whether to orient edges in the positive
* direction
- * @param minEdgePartitions the number of partitions for the edge RDD
+ * @param numEdgePartitions the number of partitions for the edge RDD
+ * Setting this value to -1 will use the default parallelism.
* @param edgeStorageLevel the desired storage level for the edge partitions
* @param vertexStorageLevel the desired storage level for the vertex partitions
*/
@@ -56,7 +57,7 @@ object GraphLoader extends Logging {
sc: SparkContext,
path: String,
canonicalOrientation: Boolean = false,
- minEdgePartitions: Int = 1,
+ numEdgePartitions: Int = -1,
edgeStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY,
vertexStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY)
: Graph[Int, Int] =
@@ -64,7 +65,12 @@ object GraphLoader extends Logging {
val startTime = System.currentTimeMillis
// Parse the edge data table directly into edge partitions
- val lines = sc.textFile(path, minEdgePartitions).coalesce(minEdgePartitions)
+ val lines =
+ if (numEdgePartitions > 0) {
+ sc.textFile(path, numEdgePartitions).coalesce(numEdgePartitions)
+ } else {
+ sc.textFile(path)
+ }
val edges = lines.mapPartitionsWithIndex { (pid, iter) =>
val builder = new EdgePartitionBuilder[Int, Int]
iter.foreach { line =>
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
index d0dd45dba618e..116d1ea700175 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
@@ -69,11 +69,12 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
*/
private def degreesRDD(edgeDirection: EdgeDirection): VertexRDD[Int] = {
if (edgeDirection == EdgeDirection.In) {
- graph.mapReduceTriplets(et => Iterator((et.dstId,1)), _ + _)
+ graph.aggregateMessages(_.sendToDst(1), _ + _, TripletFields.None)
} else if (edgeDirection == EdgeDirection.Out) {
- graph.mapReduceTriplets(et => Iterator((et.srcId,1)), _ + _)
+ graph.aggregateMessages(_.sendToSrc(1), _ + _, TripletFields.None)
} else { // EdgeDirection.Either
- graph.mapReduceTriplets(et => Iterator((et.srcId,1), (et.dstId,1)), _ + _)
+ graph.aggregateMessages(ctx => { ctx.sendToSrc(1); ctx.sendToDst(1) }, _ + _,
+ TripletFields.None)
}
}
@@ -88,18 +89,17 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexId]] = {
val nbrs =
if (edgeDirection == EdgeDirection.Either) {
- graph.mapReduceTriplets[Array[VertexId]](
- mapFunc = et => Iterator((et.srcId, Array(et.dstId)), (et.dstId, Array(et.srcId))),
- reduceFunc = _ ++ _
- )
+ graph.aggregateMessages[Array[VertexId]](
+ ctx => { ctx.sendToSrc(Array(ctx.dstId)); ctx.sendToDst(Array(ctx.srcId)) },
+ _ ++ _, TripletFields.None)
} else if (edgeDirection == EdgeDirection.Out) {
- graph.mapReduceTriplets[Array[VertexId]](
- mapFunc = et => Iterator((et.srcId, Array(et.dstId))),
- reduceFunc = _ ++ _)
+ graph.aggregateMessages[Array[VertexId]](
+ ctx => ctx.sendToSrc(Array(ctx.dstId)),
+ _ ++ _, TripletFields.None)
} else if (edgeDirection == EdgeDirection.In) {
- graph.mapReduceTriplets[Array[VertexId]](
- mapFunc = et => Iterator((et.dstId, Array(et.srcId))),
- reduceFunc = _ ++ _)
+ graph.aggregateMessages[Array[VertexId]](
+ ctx => ctx.sendToDst(Array(ctx.srcId)),
+ _ ++ _, TripletFields.None)
} else {
throw new SparkException("It doesn't make sense to collect neighbor ids without a " +
"direction. (EdgeDirection.Both is not supported; use EdgeDirection.Either instead.)")
@@ -122,22 +122,27 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
* @return the vertex set of neighboring vertex attributes for each vertex
*/
def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] = {
- val nbrs = graph.mapReduceTriplets[Array[(VertexId,VD)]](
- edge => {
- val msgToSrc = (edge.srcId, Array((edge.dstId, edge.dstAttr)))
- val msgToDst = (edge.dstId, Array((edge.srcId, edge.srcAttr)))
- edgeDirection match {
- case EdgeDirection.Either => Iterator(msgToSrc, msgToDst)
- case EdgeDirection.In => Iterator(msgToDst)
- case EdgeDirection.Out => Iterator(msgToSrc)
- case EdgeDirection.Both =>
- throw new SparkException("collectNeighbors does not support EdgeDirection.Both. Use" +
- "EdgeDirection.Either instead.")
- }
- },
- (a, b) => a ++ b)
-
- graph.vertices.leftZipJoin(nbrs) { (vid, vdata, nbrsOpt) =>
+ val nbrs = edgeDirection match {
+ case EdgeDirection.Either =>
+ graph.aggregateMessages[Array[(VertexId,VD)]](
+ ctx => {
+ ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr)))
+ ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr)))
+ },
+ (a, b) => a ++ b, TripletFields.All)
+ case EdgeDirection.In =>
+ graph.aggregateMessages[Array[(VertexId,VD)]](
+ ctx => ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))),
+ (a, b) => a ++ b, TripletFields.Src)
+ case EdgeDirection.Out =>
+ graph.aggregateMessages[Array[(VertexId,VD)]](
+ ctx => ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))),
+ (a, b) => a ++ b, TripletFields.Dst)
+ case EdgeDirection.Both =>
+ throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" +
+ "EdgeDirection.Either instead.")
+ }
+ graph.vertices.leftJoin(nbrs) { (vid, vdata, nbrsOpt) =>
nbrsOpt.getOrElse(Array.empty[(VertexId, VD)])
}
} // end of collectNeighbor
@@ -160,18 +165,20 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
def collectEdges(edgeDirection: EdgeDirection): VertexRDD[Array[Edge[ED]]] = {
edgeDirection match {
case EdgeDirection.Either =>
- graph.mapReduceTriplets[Array[Edge[ED]]](
- edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr))),
- (edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
- (a, b) => a ++ b)
+ graph.aggregateMessages[Array[Edge[ED]]](
+ ctx => {
+ ctx.sendToSrc(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr)))
+ ctx.sendToDst(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr)))
+ },
+ (a, b) => a ++ b, TripletFields.EdgeOnly)
case EdgeDirection.In =>
- graph.mapReduceTriplets[Array[Edge[ED]]](
- edge => Iterator((edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
- (a, b) => a ++ b)
+ graph.aggregateMessages[Array[Edge[ED]]](
+ ctx => ctx.sendToDst(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))),
+ (a, b) => a ++ b, TripletFields.EdgeOnly)
case EdgeDirection.Out =>
- graph.mapReduceTriplets[Array[Edge[ED]]](
- edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
- (a, b) => a ++ b)
+ graph.aggregateMessages[Array[Edge[ED]]](
+ ctx => ctx.sendToSrc(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))),
+ (a, b) => a ++ b, TripletFields.EdgeOnly)
case EdgeDirection.Both =>
throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" +
"EdgeDirection.Either instead.")
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala
new file mode 100644
index 0000000000000..2cb07937eaa2a
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.graphx
+
+import org.apache.spark.SparkConf
+
+import org.apache.spark.graphx.impl._
+import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
+
+import org.apache.spark.util.collection.{OpenHashSet, BitSet}
+import org.apache.spark.util.BoundedPriorityQueue
+
+object GraphXUtils {
+ /**
+ * Registers classes that GraphX uses with Kryo.
+ */
+ def registerKryoClasses(conf: SparkConf) {
+ conf.registerKryoClasses(Array(
+ classOf[Edge[Object]],
+ classOf[(VertexId, Object)],
+ classOf[EdgePartition[Object, Object]],
+ classOf[BitSet],
+ classOf[VertexIdToIndexMap],
+ classOf[VertexAttributeBlock[Object]],
+ classOf[PartitionStrategy],
+ classOf[BoundedPriorityQueue[Object]],
+ classOf[EdgeDirection],
+ classOf[GraphXPrimitiveKeyOpenHashMap[VertexId, Int]],
+ classOf[OpenHashSet[Int]],
+ classOf[OpenHashSet[Long]]))
+ }
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java b/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java
new file mode 100644
index 0000000000000..7eb4ae0f44602
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.graphx;
+
+import java.io.Serializable;
+
+/**
+ * Represents a subset of the fields of an [[EdgeTriplet]] or [[EdgeContext]]. This allows the
+ * system to populate only those fields for efficiency.
+ */
+public class TripletFields implements Serializable {
+
+ /** Indicates whether the source vertex attribute is included. */
+ public final boolean useSrc;
+
+ /** Indicates whether the destination vertex attribute is included. */
+ public final boolean useDst;
+
+ /** Indicates whether the edge attribute is included. */
+ public final boolean useEdge;
+
+ /** Constructs a default TripletFields in which all fields are included. */
+ public TripletFields() {
+ this(true, true, true);
+ }
+
+ public TripletFields(boolean useSrc, boolean useDst, boolean useEdge) {
+ this.useSrc = useSrc;
+ this.useDst = useDst;
+ this.useEdge = useEdge;
+ }
+
+ /**
+ * None of the triplet fields are exposed.
+ */
+ public static final TripletFields None = new TripletFields(false, false, false);
+
+ /**
+ * Expose only the edge field and not the source or destination field.
+ */
+ public static final TripletFields EdgeOnly = new TripletFields(false, false, true);
+
+ /**
+ * Expose the source and edge fields but not the destination field. (Same as Src)
+ */
+ public static final TripletFields Src = new TripletFields(true, false, true);
+
+ /**
+ * Expose the destination and edge fields but not the source field. (Same as Dst)
+ */
+ public static final TripletFields Dst = new TripletFields(false, true, true);
+
+ /**
+ * Expose all the fields (source, edge, and destination).
+ */
+ public static final TripletFields All = new TripletFields(true, true, true);
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
index 2c8b245955d12..1db3df03c8052 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala
@@ -27,8 +27,7 @@ import org.apache.spark.storage.StorageLevel
import org.apache.spark.graphx.impl.RoutingTablePartition
import org.apache.spark.graphx.impl.ShippableVertexPartition
import org.apache.spark.graphx.impl.VertexAttributeBlock
-import org.apache.spark.graphx.impl.RoutingTableMessageRDDFunctions._
-import org.apache.spark.graphx.impl.VertexRDDFunctions._
+import org.apache.spark.graphx.impl.VertexRDDImpl
/**
* Extends `RDD[(VertexId, VD)]` by ensuring that there is only one entry for each vertex and by
@@ -55,62 +54,16 @@ import org.apache.spark.graphx.impl.VertexRDDFunctions._
*
* @tparam VD the vertex attribute associated with each vertex in the set.
*/
-class VertexRDD[@specialized VD: ClassTag](
- val partitionsRDD: RDD[ShippableVertexPartition[VD]],
- val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY)
- extends RDD[(VertexId, VD)](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) {
+abstract class VertexRDD[VD](
+ @transient sc: SparkContext,
+ @transient deps: Seq[Dependency[_]]) extends RDD[(VertexId, VD)](sc, deps) {
- require(partitionsRDD.partitioner.isDefined)
+ implicit protected def vdTag: ClassTag[VD]
- /**
- * Construct a new VertexRDD that is indexed by only the visible vertices. The resulting
- * VertexRDD will be based on a different index and can no longer be quickly joined with this
- * RDD.
- */
- def reindex(): VertexRDD[VD] = this.withPartitionsRDD(partitionsRDD.map(_.reindex()))
-
- override val partitioner = partitionsRDD.partitioner
+ private[graphx] def partitionsRDD: RDD[ShippableVertexPartition[VD]]
override protected def getPartitions: Array[Partition] = partitionsRDD.partitions
- override protected def getPreferredLocations(s: Partition): Seq[String] =
- partitionsRDD.preferredLocations(s)
-
- override def setName(_name: String): this.type = {
- if (partitionsRDD.name != null) {
- partitionsRDD.setName(partitionsRDD.name + ", " + _name)
- } else {
- partitionsRDD.setName(_name)
- }
- this
- }
- setName("VertexRDD")
-
- /**
- * Persists the vertex partitions at the specified storage level, ignoring any existing target
- * storage level.
- */
- override def persist(newLevel: StorageLevel): this.type = {
- partitionsRDD.persist(newLevel)
- this
- }
-
- override def unpersist(blocking: Boolean = true): this.type = {
- partitionsRDD.unpersist(blocking)
- this
- }
-
- /** Persists the vertex partitions at `targetStorageLevel`, which defaults to MEMORY_ONLY. */
- override def cache(): this.type = {
- partitionsRDD.persist(targetStorageLevel)
- this
- }
-
- /** The number of vertices in the RDD. */
- override def count(): Long = {
- partitionsRDD.map(_.size.toLong).reduce(_ + _)
- }
-
/**
* Provides the `RDD[(VertexId, VD)]` equivalent output.
*/
@@ -118,22 +71,28 @@ class VertexRDD[@specialized VD: ClassTag](
firstParent[ShippableVertexPartition[VD]].iterator(part, context).next.iterator
}
+ /**
+ * Construct a new VertexRDD that is indexed by only the visible vertices. The resulting
+ * VertexRDD will be based on a different index and can no longer be quickly joined with this
+ * RDD.
+ */
+ def reindex(): VertexRDD[VD]
+
/**
* Applies a function to each `VertexPartition` of this RDD and returns a new VertexRDD.
*/
private[graphx] def mapVertexPartitions[VD2: ClassTag](
f: ShippableVertexPartition[VD] => ShippableVertexPartition[VD2])
- : VertexRDD[VD2] = {
- val newPartitionsRDD = partitionsRDD.mapPartitions(_.map(f), preservesPartitioning = true)
- this.withPartitionsRDD(newPartitionsRDD)
- }
-
+ : VertexRDD[VD2]
/**
* Restricts the vertex set to the set of vertices satisfying the given predicate. This operation
* preserves the index for efficient joins with the original RDD, and it sets bits in the bitmask
* rather than allocating new memory.
*
+ * It is declared and defined here to allow refining the return type from `RDD[(VertexId, VD)]` to
+ * `VertexRDD[VD]`.
+ *
* @param pred the user defined predicate, which takes a tuple to conform to the
* `RDD[(VertexId, VD)]` interface
*/
@@ -149,8 +108,7 @@ class VertexRDD[@specialized VD: ClassTag](
* @return a new VertexRDD with values obtained by applying `f` to each of the entries in the
* original VertexRDD
*/
- def mapValues[VD2: ClassTag](f: VD => VD2): VertexRDD[VD2] =
- this.mapVertexPartitions(_.map((vid, attr) => f(attr)))
+ def mapValues[VD2: ClassTag](f: VD => VD2): VertexRDD[VD2]
/**
* Maps each vertex attribute, additionally supplying the vertex ID.
@@ -161,23 +119,13 @@ class VertexRDD[@specialized VD: ClassTag](
* @return a new VertexRDD with values obtained by applying `f` to each of the entries in the
* original VertexRDD. The resulting VertexRDD retains the same index.
*/
- def mapValues[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexRDD[VD2] =
- this.mapVertexPartitions(_.map(f))
+ def mapValues[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexRDD[VD2]
/**
* Hides vertices that are the same between `this` and `other`; for vertices that are different,
* keeps the values from `other`.
*/
- def diff(other: VertexRDD[VD]): VertexRDD[VD] = {
- val newPartitionsRDD = partitionsRDD.zipPartitions(
- other.partitionsRDD, preservesPartitioning = true
- ) { (thisIter, otherIter) =>
- val thisPart = thisIter.next()
- val otherPart = otherIter.next()
- Iterator(thisPart.diff(otherPart))
- }
- this.withPartitionsRDD(newPartitionsRDD)
- }
+ def diff(other: VertexRDD[VD]): VertexRDD[VD]
/**
* Left joins this RDD with another VertexRDD with the same index. This function will fail if
@@ -194,16 +142,7 @@ class VertexRDD[@specialized VD: ClassTag](
* @return a VertexRDD containing the results of `f`
*/
def leftZipJoin[VD2: ClassTag, VD3: ClassTag]
- (other: VertexRDD[VD2])(f: (VertexId, VD, Option[VD2]) => VD3): VertexRDD[VD3] = {
- val newPartitionsRDD = partitionsRDD.zipPartitions(
- other.partitionsRDD, preservesPartitioning = true
- ) { (thisIter, otherIter) =>
- val thisPart = thisIter.next()
- val otherPart = otherIter.next()
- Iterator(thisPart.leftJoin(otherPart)(f))
- }
- this.withPartitionsRDD(newPartitionsRDD)
- }
+ (other: VertexRDD[VD2])(f: (VertexId, VD, Option[VD2]) => VD3): VertexRDD[VD3]
/**
* Left joins this VertexRDD with an RDD containing vertex attribute pairs. If the other RDD is
@@ -224,37 +163,14 @@ class VertexRDD[@specialized VD: ClassTag](
def leftJoin[VD2: ClassTag, VD3: ClassTag]
(other: RDD[(VertexId, VD2)])
(f: (VertexId, VD, Option[VD2]) => VD3)
- : VertexRDD[VD3] = {
- // Test if the other vertex is a VertexRDD to choose the optimal join strategy.
- // If the other set is a VertexRDD then we use the much more efficient leftZipJoin
- other match {
- case other: VertexRDD[_] =>
- leftZipJoin(other)(f)
- case _ =>
- this.withPartitionsRDD[VD3](
- partitionsRDD.zipPartitions(
- other.copartitionWithVertices(this.partitioner.get), preservesPartitioning = true) {
- (partIter, msgs) => partIter.map(_.leftJoin(msgs)(f))
- }
- )
- }
- }
+ : VertexRDD[VD3]
/**
* Efficiently inner joins this VertexRDD with another VertexRDD sharing the same index. See
* [[innerJoin]] for the behavior of the join.
*/
def innerZipJoin[U: ClassTag, VD2: ClassTag](other: VertexRDD[U])
- (f: (VertexId, VD, U) => VD2): VertexRDD[VD2] = {
- val newPartitionsRDD = partitionsRDD.zipPartitions(
- other.partitionsRDD, preservesPartitioning = true
- ) { (thisIter, otherIter) =>
- val thisPart = thisIter.next()
- val otherPart = otherIter.next()
- Iterator(thisPart.innerJoin(otherPart)(f))
- }
- this.withPartitionsRDD(newPartitionsRDD)
- }
+ (f: (VertexId, VD, U) => VD2): VertexRDD[VD2]
/**
* Inner joins this VertexRDD with an RDD containing vertex attribute pairs. If the other RDD is
@@ -268,21 +184,7 @@ class VertexRDD[@specialized VD: ClassTag](
* `this` and `other`, with values supplied by `f`
*/
def innerJoin[U: ClassTag, VD2: ClassTag](other: RDD[(VertexId, U)])
- (f: (VertexId, VD, U) => VD2): VertexRDD[VD2] = {
- // Test if the other vertex is a VertexRDD to choose the optimal join strategy.
- // If the other set is a VertexRDD then we use the much more efficient innerZipJoin
- other match {
- case other: VertexRDD[_] =>
- innerZipJoin(other)(f)
- case _ =>
- this.withPartitionsRDD(
- partitionsRDD.zipPartitions(
- other.copartitionWithVertices(this.partitioner.get), preservesPartitioning = true) {
- (partIter, msgs) => partIter.map(_.innerJoin(msgs)(f))
- }
- )
- }
- }
+ (f: (VertexId, VD, U) => VD2): VertexRDD[VD2]
/**
* Aggregates vertices in `messages` that have the same ids using `reduceFunc`, returning a
@@ -296,38 +198,20 @@ class VertexRDD[@specialized VD: ClassTag](
* messages.
*/
def aggregateUsingIndex[VD2: ClassTag](
- messages: RDD[(VertexId, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2] = {
- val shuffled = messages.copartitionWithVertices(this.partitioner.get)
- val parts = partitionsRDD.zipPartitions(shuffled, true) { (thisIter, msgIter) =>
- thisIter.map(_.aggregateUsingIndex(msgIter, reduceFunc))
- }
- this.withPartitionsRDD[VD2](parts)
- }
+ messages: RDD[(VertexId, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2]
/**
* Returns a new `VertexRDD` reflecting a reversal of all edge directions in the corresponding
* [[EdgeRDD]].
*/
- def reverseRoutingTables(): VertexRDD[VD] =
- this.mapVertexPartitions(vPart => vPart.withRoutingTable(vPart.routingTable.reverse))
+ def reverseRoutingTables(): VertexRDD[VD]
/** Prepares this VertexRDD for efficient joins with the given EdgeRDD. */
- def withEdges(edges: EdgeRDD[_, _]): VertexRDD[VD] = {
- val routingTables = VertexRDD.createRoutingTables(edges, this.partitioner.get)
- val vertexPartitions = partitionsRDD.zipPartitions(routingTables, true) {
- (partIter, routingTableIter) =>
- val routingTable =
- if (routingTableIter.hasNext) routingTableIter.next() else RoutingTablePartition.empty
- partIter.map(_.withRoutingTable(routingTable))
- }
- this.withPartitionsRDD(vertexPartitions)
- }
+ def withEdges(edges: EdgeRDD[_]): VertexRDD[VD]
/** Replaces the vertex partitions while preserving all other properties of the VertexRDD. */
private[graphx] def withPartitionsRDD[VD2: ClassTag](
- partitionsRDD: RDD[ShippableVertexPartition[VD2]]): VertexRDD[VD2] = {
- new VertexRDD(partitionsRDD, this.targetStorageLevel)
- }
+ partitionsRDD: RDD[ShippableVertexPartition[VD2]]): VertexRDD[VD2]
/**
* Changes the target storage level while preserving all other properties of the
@@ -337,20 +221,14 @@ class VertexRDD[@specialized VD: ClassTag](
* [[org.apache.spark.graphx.VertexRDD#cache]] on the returned VertexRDD.
*/
private[graphx] def withTargetStorageLevel(
- targetStorageLevel: StorageLevel): VertexRDD[VD] = {
- new VertexRDD(this.partitionsRDD, targetStorageLevel)
- }
+ targetStorageLevel: StorageLevel): VertexRDD[VD]
/** Generates an RDD of vertex attributes suitable for shipping to the edge partitions. */
private[graphx] def shipVertexAttributes(
- shipSrc: Boolean, shipDst: Boolean): RDD[(PartitionID, VertexAttributeBlock[VD])] = {
- partitionsRDD.mapPartitions(_.flatMap(_.shipVertexAttributes(shipSrc, shipDst)))
- }
+ shipSrc: Boolean, shipDst: Boolean): RDD[(PartitionID, VertexAttributeBlock[VD])]
/** Generates an RDD of vertex IDs suitable for shipping to the edge partitions. */
- private[graphx] def shipVertexIds(): RDD[(PartitionID, Array[VertexId])] = {
- partitionsRDD.mapPartitions(_.flatMap(_.shipVertexIds()))
- }
+ private[graphx] def shipVertexIds(): RDD[(PartitionID, Array[VertexId])]
} // end of VertexRDD
@@ -371,12 +249,12 @@ object VertexRDD {
def apply[VD: ClassTag](vertices: RDD[(VertexId, VD)]): VertexRDD[VD] = {
val vPartitioned: RDD[(VertexId, VD)] = vertices.partitioner match {
case Some(p) => vertices
- case None => vertices.copartitionWithVertices(new HashPartitioner(vertices.partitions.size))
+ case None => vertices.partitionBy(new HashPartitioner(vertices.partitions.size))
}
val vertexPartitions = vPartitioned.mapPartitions(
iter => Iterator(ShippableVertexPartition(iter)),
preservesPartitioning = true)
- new VertexRDD(vertexPartitions)
+ new VertexRDDImpl(vertexPartitions)
}
/**
@@ -391,7 +269,7 @@ object VertexRDD {
* @param defaultVal the vertex attribute to use when creating missing vertices
*/
def apply[VD: ClassTag](
- vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_, _], defaultVal: VD): VertexRDD[VD] = {
+ vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_], defaultVal: VD): VertexRDD[VD] = {
VertexRDD(vertices, edges, defaultVal, (a, b) => a)
}
@@ -408,11 +286,11 @@ object VertexRDD {
* @param mergeFunc the commutative, associative duplicate vertex attribute merge function
*/
def apply[VD: ClassTag](
- vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_, _], defaultVal: VD, mergeFunc: (VD, VD) => VD
+ vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_], defaultVal: VD, mergeFunc: (VD, VD) => VD
): VertexRDD[VD] = {
val vPartitioned: RDD[(VertexId, VD)] = vertices.partitioner match {
case Some(p) => vertices
- case None => vertices.copartitionWithVertices(new HashPartitioner(vertices.partitions.size))
+ case None => vertices.partitionBy(new HashPartitioner(vertices.partitions.size))
}
val routingTables = createRoutingTables(edges, vPartitioned.partitioner.get)
val vertexPartitions = vPartitioned.zipPartitions(routingTables, preservesPartitioning = true) {
@@ -421,7 +299,7 @@ object VertexRDD {
if (routingTableIter.hasNext) routingTableIter.next() else RoutingTablePartition.empty
Iterator(ShippableVertexPartition(vertexIter, routingTable, defaultVal, mergeFunc))
}
- new VertexRDD(vertexPartitions)
+ new VertexRDDImpl(vertexPartitions)
}
/**
@@ -436,25 +314,25 @@ object VertexRDD {
* @param defaultVal the vertex attribute to use when creating missing vertices
*/
def fromEdges[VD: ClassTag](
- edges: EdgeRDD[_, _], numPartitions: Int, defaultVal: VD): VertexRDD[VD] = {
+ edges: EdgeRDD[_], numPartitions: Int, defaultVal: VD): VertexRDD[VD] = {
val routingTables = createRoutingTables(edges, new HashPartitioner(numPartitions))
val vertexPartitions = routingTables.mapPartitions({ routingTableIter =>
val routingTable =
if (routingTableIter.hasNext) routingTableIter.next() else RoutingTablePartition.empty
Iterator(ShippableVertexPartition(Iterator.empty, routingTable, defaultVal))
}, preservesPartitioning = true)
- new VertexRDD(vertexPartitions)
+ new VertexRDDImpl(vertexPartitions)
}
- private def createRoutingTables(
- edges: EdgeRDD[_, _], vertexPartitioner: Partitioner): RDD[RoutingTablePartition] = {
+ private[graphx] def createRoutingTables(
+ edges: EdgeRDD[_], vertexPartitioner: Partitioner): RDD[RoutingTablePartition] = {
// Determine which vertices each edge partition needs by creating a mapping from vid to pid.
val vid2pid = edges.partitionsRDD.mapPartitions(_.flatMap(
Function.tupled(RoutingTablePartition.edgePartitionToMsgs)))
.setName("VertexRDD.createRoutingTables - vid2pid (aggregation)")
val numEdgePartitions = edges.partitions.size
- vid2pid.copartitionWithVertices(vertexPartitioner).mapPartitions(
+ vid2pid.partitionBy(vertexPartitioner).mapPartitions(
iter => Iterator(RoutingTablePartition.fromMsgs(numEdgePartitions, iter)),
preservesPartitioning = true)
}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeActiveness.java b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeActiveness.java
new file mode 100644
index 0000000000000..377ae849f045c
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeActiveness.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.graphx.impl;
+
+/**
+ * Criteria for filtering edges based on activeness. For internal use only.
+ */
+public enum EdgeActiveness {
+ /** Neither the source vertex nor the destination vertex need be active. */
+ Neither,
+ /** The source vertex must be active. */
+ SrcOnly,
+ /** The destination vertex must be active. */
+ DstOnly,
+ /** Both vertices must be active. */
+ Both,
+ /** At least one vertex must be active. */
+ Either
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
index a5c9cd1f8b4e6..373af75448374 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
@@ -21,63 +21,94 @@ import scala.reflect.{classTag, ClassTag}
import org.apache.spark.graphx._
import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
+import org.apache.spark.util.collection.BitSet
/**
- * A collection of edges stored in columnar format, along with any vertex attributes referenced. The
- * edges are stored in 3 large columnar arrays (src, dst, attribute). The arrays are clustered by
- * src. There is an optional active vertex set for filtering computation on the edges.
+ * A collection of edges, along with referenced vertex attributes and an optional active vertex set
+ * for filtering computation on the edges.
+ *
+ * The edges are stored in columnar format in `localSrcIds`, `localDstIds`, and `data`. All
+ * referenced global vertex ids are mapped to a compact set of local vertex ids according to the
+ * `global2local` map. Each local vertex id is a valid index into `vertexAttrs`, which stores the
+ * corresponding vertex attribute, and `local2global`, which stores the reverse mapping to global
+ * vertex id. The global vertex ids that are active are optionally stored in `activeSet`.
+ *
+ * The edges are clustered by source vertex id, and the mapping from global vertex id to the index
+ * of the corresponding edge cluster is stored in `index`.
*
* @tparam ED the edge attribute type
* @tparam VD the vertex attribute type
*
- * @param srcIds the source vertex id of each edge
- * @param dstIds the destination vertex id of each edge
+ * @param localSrcIds the local source vertex id of each edge as an index into `local2global` and
+ * `vertexAttrs`
+ * @param localDstIds the local destination vertex id of each edge as an index into `local2global`
+ * and `vertexAttrs`
* @param data the attribute associated with each edge
- * @param index a clustered index on source vertex id
- * @param vertices a map from referenced vertex ids to their corresponding attributes. Must
- * contain all vertex ids from `srcIds` and `dstIds`, though not necessarily valid attributes for
- * those vertex ids. The mask is not used.
+ * @param index a clustered index on source vertex id as a map from each global source vertex id to
+ * the offset in the edge arrays where the cluster for that vertex id begins
+ * @param global2local a map from referenced vertex ids to local ids which index into vertexAttrs
+ * @param local2global an array of global vertex ids where the offsets are local vertex ids
+ * @param vertexAttrs an array of vertex attributes where the offsets are local vertex ids
* @param activeSet an optional active vertex set for filtering computation on the edges
*/
private[graphx]
class EdgePartition[
@specialized(Char, Int, Boolean, Byte, Long, Float, Double) ED: ClassTag, VD: ClassTag](
- val srcIds: Array[VertexId] = null,
- val dstIds: Array[VertexId] = null,
- val data: Array[ED] = null,
- val index: GraphXPrimitiveKeyOpenHashMap[VertexId, Int] = null,
- val vertices: VertexPartition[VD] = null,
- val activeSet: Option[VertexSet] = None
- ) extends Serializable {
+ localSrcIds: Array[Int],
+ localDstIds: Array[Int],
+ data: Array[ED],
+ index: GraphXPrimitiveKeyOpenHashMap[VertexId, Int],
+ global2local: GraphXPrimitiveKeyOpenHashMap[VertexId, Int],
+ local2global: Array[VertexId],
+ vertexAttrs: Array[VD],
+ activeSet: Option[VertexSet])
+ extends Serializable {
- /** Return a new `EdgePartition` with the specified edge data. */
- def withData[ED2: ClassTag](data_ : Array[ED2]): EdgePartition[ED2, VD] = {
- new EdgePartition(srcIds, dstIds, data_, index, vertices, activeSet)
- }
+ /** No-arg constructor for serialization. */
+ private def this() = this(null, null, null, null, null, null, null, null)
- /** Return a new `EdgePartition` with the specified vertex partition. */
- def withVertices[VD2: ClassTag](
- vertices_ : VertexPartition[VD2]): EdgePartition[ED, VD2] = {
- new EdgePartition(srcIds, dstIds, data, index, vertices_, activeSet)
+ /** Return a new `EdgePartition` with the specified edge data. */
+ def withData[ED2: ClassTag](data: Array[ED2]): EdgePartition[ED2, VD] = {
+ new EdgePartition(
+ localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs, activeSet)
}
/** Return a new `EdgePartition` with the specified active set, provided as an iterator. */
def withActiveSet(iter: Iterator[VertexId]): EdgePartition[ED, VD] = {
- val newActiveSet = new VertexSet
- iter.foreach(newActiveSet.add(_))
- new EdgePartition(srcIds, dstIds, data, index, vertices, Some(newActiveSet))
- }
-
- /** Return a new `EdgePartition` with the specified active set. */
- def withActiveSet(activeSet_ : Option[VertexSet]): EdgePartition[ED, VD] = {
- new EdgePartition(srcIds, dstIds, data, index, vertices, activeSet_)
+ val activeSet = new VertexSet
+ while (iter.hasNext) { activeSet.add(iter.next()) }
+ new EdgePartition(
+ localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs,
+ Some(activeSet))
}
/** Return a new `EdgePartition` with updates to vertex attributes specified in `iter`. */
def updateVertices(iter: Iterator[(VertexId, VD)]): EdgePartition[ED, VD] = {
- this.withVertices(vertices.innerJoinKeepLeft(iter))
+ val newVertexAttrs = new Array[VD](vertexAttrs.length)
+ System.arraycopy(vertexAttrs, 0, newVertexAttrs, 0, vertexAttrs.length)
+ while (iter.hasNext) {
+ val kv = iter.next()
+ newVertexAttrs(global2local(kv._1)) = kv._2
+ }
+ new EdgePartition(
+ localSrcIds, localDstIds, data, index, global2local, local2global, newVertexAttrs,
+ activeSet)
}
+ /** Return a new `EdgePartition` without any locally cached vertex attributes. */
+ def withoutVertexAttributes[VD2: ClassTag](): EdgePartition[ED, VD2] = {
+ val newVertexAttrs = new Array[VD2](vertexAttrs.length)
+ new EdgePartition(
+ localSrcIds, localDstIds, data, index, global2local, local2global, newVertexAttrs,
+ activeSet)
+ }
+
+ @inline private def srcIds(pos: Int): VertexId = local2global(localSrcIds(pos))
+
+ @inline private def dstIds(pos: Int): VertexId = local2global(localDstIds(pos))
+
+ @inline private def attrs(pos: Int): ED = data(pos)
+
/** Look up vid in activeSet, throwing an exception if it is None. */
def isActive(vid: VertexId): Boolean = {
activeSet.get.contains(vid)
@@ -92,11 +123,19 @@ class EdgePartition[
* @return a new edge partition with all edges reversed.
*/
def reverse: EdgePartition[ED, VD] = {
- val builder = new EdgePartitionBuilder(size)(classTag[ED], classTag[VD])
- for (e <- iterator) {
- builder.add(e.dstId, e.srcId, e.attr)
+ val builder = new ExistingEdgePartitionBuilder[ED, VD](
+ global2local, local2global, vertexAttrs, activeSet, size)
+ var i = 0
+ while (i < size) {
+ val localSrcId = localSrcIds(i)
+ val localDstId = localDstIds(i)
+ val srcId = local2global(localSrcId)
+ val dstId = local2global(localDstId)
+ val attr = data(i)
+ builder.add(dstId, srcId, localDstId, localSrcId, attr)
+ i += 1
}
- builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet)
+ builder.toEdgePartition
}
/**
@@ -157,13 +196,25 @@ class EdgePartition[
def filter(
epred: EdgeTriplet[VD, ED] => Boolean,
vpred: (VertexId, VD) => Boolean): EdgePartition[ED, VD] = {
- val filtered = tripletIterator().filter(et =>
- vpred(et.srcId, et.srcAttr) && vpred(et.dstId, et.dstAttr) && epred(et))
- val builder = new EdgePartitionBuilder[ED, VD]
- for (e <- filtered) {
- builder.add(e.srcId, e.dstId, e.attr)
+ val builder = new ExistingEdgePartitionBuilder[ED, VD](
+ global2local, local2global, vertexAttrs, activeSet)
+ var i = 0
+ while (i < size) {
+ // The user sees the EdgeTriplet, so we can't reuse it and must create one per edge.
+ val localSrcId = localSrcIds(i)
+ val localDstId = localDstIds(i)
+ val et = new EdgeTriplet[VD, ED]
+ et.srcId = local2global(localSrcId)
+ et.dstId = local2global(localDstId)
+ et.srcAttr = vertexAttrs(localSrcId)
+ et.dstAttr = vertexAttrs(localDstId)
+ et.attr = data(i)
+ if (vpred(et.srcId, et.srcAttr) && vpred(et.dstId, et.dstAttr) && epred(et)) {
+ builder.add(et.srcId, et.dstId, localSrcId, localDstId, et.attr)
+ }
+ i += 1
}
- builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet)
+ builder.toEdgePartition
}
/**
@@ -183,28 +234,40 @@ class EdgePartition[
* @return a new edge partition without duplicate edges
*/
def groupEdges(merge: (ED, ED) => ED): EdgePartition[ED, VD] = {
- val builder = new EdgePartitionBuilder[ED, VD]
+ val builder = new ExistingEdgePartitionBuilder[ED, VD](
+ global2local, local2global, vertexAttrs, activeSet)
var currSrcId: VertexId = null.asInstanceOf[VertexId]
var currDstId: VertexId = null.asInstanceOf[VertexId]
+ var currLocalSrcId = -1
+ var currLocalDstId = -1
var currAttr: ED = null.asInstanceOf[ED]
+ // Iterate through the edges, accumulating runs of identical edges using the curr* variables and
+ // releasing them to the builder when we see the beginning of the next run
var i = 0
while (i < size) {
if (i > 0 && currSrcId == srcIds(i) && currDstId == dstIds(i)) {
+ // This edge should be accumulated into the existing run
currAttr = merge(currAttr, data(i))
} else {
+ // This edge starts a new run of edges
if (i > 0) {
- builder.add(currSrcId, currDstId, currAttr)
+ // First release the existing run to the builder
+ builder.add(currSrcId, currDstId, currLocalSrcId, currLocalDstId, currAttr)
}
+ // Then start accumulating for a new run
currSrcId = srcIds(i)
currDstId = dstIds(i)
+ currLocalSrcId = localSrcIds(i)
+ currLocalDstId = localDstIds(i)
currAttr = data(i)
}
i += 1
}
+ // Finally, release the last accumulated run
if (size > 0) {
- builder.add(currSrcId, currDstId, currAttr)
+ builder.add(currSrcId, currDstId, currLocalSrcId, currLocalDstId, currAttr)
}
- builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet)
+ builder.toEdgePartition
}
/**
@@ -220,7 +283,8 @@ class EdgePartition[
def innerJoin[ED2: ClassTag, ED3: ClassTag]
(other: EdgePartition[ED2, _])
(f: (VertexId, VertexId, ED, ED2) => ED3): EdgePartition[ED3, VD] = {
- val builder = new EdgePartitionBuilder[ED3, VD]
+ val builder = new ExistingEdgePartitionBuilder[ED3, VD](
+ global2local, local2global, vertexAttrs, activeSet)
var i = 0
var j = 0
// For i = index of each edge in `this`...
@@ -233,12 +297,13 @@ class EdgePartition[
while (j < other.size && other.srcIds(j) == srcId && other.dstIds(j) < dstId) { j += 1 }
if (j < other.size && other.srcIds(j) == srcId && other.dstIds(j) == dstId) {
// ... run `f` on the matching edge
- builder.add(srcId, dstId, f(srcId, dstId, this.data(i), other.data(j)))
+ builder.add(srcId, dstId, localSrcIds(i), localDstIds(i),
+ f(srcId, dstId, this.data(i), other.attrs(j)))
}
}
i += 1
}
- builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet)
+ builder.toEdgePartition
}
/**
@@ -246,7 +311,7 @@ class EdgePartition[
*
* @return size of the partition
*/
- val size: Int = srcIds.size
+ val size: Int = localSrcIds.size
/** The number of unique source vertices in the partition. */
def indexSize: Int = index.size
@@ -280,55 +345,198 @@ class EdgePartition[
* It is safe to keep references to the objects from this iterator.
*/
def tripletIterator(
- includeSrc: Boolean = true, includeDst: Boolean = true): Iterator[EdgeTriplet[VD, ED]] = {
- new EdgeTripletIterator(this, includeSrc, includeDst)
+ includeSrc: Boolean = true, includeDst: Boolean = true)
+ : Iterator[EdgeTriplet[VD, ED]] = new Iterator[EdgeTriplet[VD, ED]] {
+ private[this] var pos = 0
+
+ override def hasNext: Boolean = pos < EdgePartition.this.size
+
+ override def next() = {
+ val triplet = new EdgeTriplet[VD, ED]
+ val localSrcId = localSrcIds(pos)
+ val localDstId = localDstIds(pos)
+ triplet.srcId = local2global(localSrcId)
+ triplet.dstId = local2global(localDstId)
+ if (includeSrc) {
+ triplet.srcAttr = vertexAttrs(localSrcId)
+ }
+ if (includeDst) {
+ triplet.dstAttr = vertexAttrs(localDstId)
+ }
+ triplet.attr = data(pos)
+ pos += 1
+ triplet
+ }
}
/**
- * Upgrade the given edge iterator into a triplet iterator.
+ * Send messages along edges and aggregate them at the receiving vertices. Implemented by scanning
+ * all edges sequentially.
*
- * Be careful not to keep references to the objects from this iterator. To improve GC performance
- * the same object is re-used in `next()`.
+ * @param sendMsg generates messages to neighboring vertices of an edge
+ * @param mergeMsg the combiner applied to messages destined to the same vertex
+ * @param tripletFields which triplet fields `sendMsg` uses
+ * @param activeness criteria for filtering edges based on activeness
+ *
+ * @return iterator aggregated messages keyed by the receiving vertex id
*/
- def upgradeIterator(
- edgeIter: Iterator[Edge[ED]], includeSrc: Boolean = true, includeDst: Boolean = true)
- : Iterator[EdgeTriplet[VD, ED]] = {
- new ReusingEdgeTripletIterator(edgeIter, this, includeSrc, includeDst)
+ def aggregateMessagesEdgeScan[A: ClassTag](
+ sendMsg: EdgeContext[VD, ED, A] => Unit,
+ mergeMsg: (A, A) => A,
+ tripletFields: TripletFields,
+ activeness: EdgeActiveness): Iterator[(VertexId, A)] = {
+ val aggregates = new Array[A](vertexAttrs.length)
+ val bitset = new BitSet(vertexAttrs.length)
+
+ var ctx = new AggregatingEdgeContext[VD, ED, A](mergeMsg, aggregates, bitset)
+ var i = 0
+ while (i < size) {
+ val localSrcId = localSrcIds(i)
+ val srcId = local2global(localSrcId)
+ val localDstId = localDstIds(i)
+ val dstId = local2global(localDstId)
+ val edgeIsActive =
+ if (activeness == EdgeActiveness.Neither) true
+ else if (activeness == EdgeActiveness.SrcOnly) isActive(srcId)
+ else if (activeness == EdgeActiveness.DstOnly) isActive(dstId)
+ else if (activeness == EdgeActiveness.Both) isActive(srcId) && isActive(dstId)
+ else if (activeness == EdgeActiveness.Either) isActive(srcId) || isActive(dstId)
+ else throw new Exception("unreachable")
+ if (edgeIsActive) {
+ val srcAttr = if (tripletFields.useSrc) vertexAttrs(localSrcId) else null.asInstanceOf[VD]
+ val dstAttr = if (tripletFields.useDst) vertexAttrs(localDstId) else null.asInstanceOf[VD]
+ ctx.set(srcId, dstId, localSrcId, localDstId, srcAttr, dstAttr, data(i))
+ sendMsg(ctx)
+ }
+ i += 1
+ }
+
+ bitset.iterator.map { localId => (local2global(localId), aggregates(localId)) }
}
/**
- * Get an iterator over the edges in this partition whose source vertex ids match srcIdPred. The
- * iterator is generated using an index scan, so it is efficient at skipping edges that don't
- * match srcIdPred.
+ * Send messages along edges and aggregate them at the receiving vertices. Implemented by
+ * filtering the source vertex index, then scanning each edge cluster.
*
- * Be careful not to keep references to the objects from this iterator. To improve GC performance
- * the same object is re-used in `next()`.
- */
- def indexIterator(srcIdPred: VertexId => Boolean): Iterator[Edge[ED]] =
- index.iterator.filter(kv => srcIdPred(kv._1)).flatMap(Function.tupled(clusterIterator))
-
- /**
- * Get an iterator over the cluster of edges in this partition with source vertex id `srcId`. The
- * cluster must start at position `index`.
+ * @param sendMsg generates messages to neighboring vertices of an edge
+ * @param mergeMsg the combiner applied to messages destined to the same vertex
+ * @param tripletFields which triplet fields `sendMsg` uses
+ * @param activeness criteria for filtering edges based on activeness
*
- * Be careful not to keep references to the objects from this iterator. To improve GC performance
- * the same object is re-used in `next()`.
+ * @return iterator aggregated messages keyed by the receiving vertex id
*/
- private def clusterIterator(srcId: VertexId, index: Int) = new Iterator[Edge[ED]] {
- private[this] val edge = new Edge[ED]
- private[this] var pos = index
+ def aggregateMessagesIndexScan[A: ClassTag](
+ sendMsg: EdgeContext[VD, ED, A] => Unit,
+ mergeMsg: (A, A) => A,
+ tripletFields: TripletFields,
+ activeness: EdgeActiveness): Iterator[(VertexId, A)] = {
+ val aggregates = new Array[A](vertexAttrs.length)
+ val bitset = new BitSet(vertexAttrs.length)
+
+ var ctx = new AggregatingEdgeContext[VD, ED, A](mergeMsg, aggregates, bitset)
+ index.iterator.foreach { cluster =>
+ val clusterSrcId = cluster._1
+ val clusterPos = cluster._2
+ val clusterLocalSrcId = localSrcIds(clusterPos)
- override def hasNext: Boolean = {
- pos >= 0 && pos < EdgePartition.this.size && srcIds(pos) == srcId
+ val scanCluster =
+ if (activeness == EdgeActiveness.Neither) true
+ else if (activeness == EdgeActiveness.SrcOnly) isActive(clusterSrcId)
+ else if (activeness == EdgeActiveness.DstOnly) true
+ else if (activeness == EdgeActiveness.Both) isActive(clusterSrcId)
+ else if (activeness == EdgeActiveness.Either) true
+ else throw new Exception("unreachable")
+
+ if (scanCluster) {
+ var pos = clusterPos
+ val srcAttr =
+ if (tripletFields.useSrc) vertexAttrs(clusterLocalSrcId) else null.asInstanceOf[VD]
+ ctx.setSrcOnly(clusterSrcId, clusterLocalSrcId, srcAttr)
+ while (pos < size && localSrcIds(pos) == clusterLocalSrcId) {
+ val localDstId = localDstIds(pos)
+ val dstId = local2global(localDstId)
+ val edgeIsActive =
+ if (activeness == EdgeActiveness.Neither) true
+ else if (activeness == EdgeActiveness.SrcOnly) true
+ else if (activeness == EdgeActiveness.DstOnly) isActive(dstId)
+ else if (activeness == EdgeActiveness.Both) isActive(dstId)
+ else if (activeness == EdgeActiveness.Either) isActive(clusterSrcId) || isActive(dstId)
+ else throw new Exception("unreachable")
+ if (edgeIsActive) {
+ val dstAttr =
+ if (tripletFields.useDst) vertexAttrs(localDstId) else null.asInstanceOf[VD]
+ ctx.setRest(dstId, localDstId, dstAttr, data(pos))
+ sendMsg(ctx)
+ }
+ pos += 1
+ }
+ }
}
- override def next(): Edge[ED] = {
- assert(srcIds(pos) == srcId)
- edge.srcId = srcIds(pos)
- edge.dstId = dstIds(pos)
- edge.attr = data(pos)
- pos += 1
- edge
+ bitset.iterator.map { localId => (local2global(localId), aggregates(localId)) }
+ }
+}
+
+private class AggregatingEdgeContext[VD, ED, A](
+ mergeMsg: (A, A) => A,
+ aggregates: Array[A],
+ bitset: BitSet)
+ extends EdgeContext[VD, ED, A] {
+
+ private[this] var _srcId: VertexId = _
+ private[this] var _dstId: VertexId = _
+ private[this] var _localSrcId: Int = _
+ private[this] var _localDstId: Int = _
+ private[this] var _srcAttr: VD = _
+ private[this] var _dstAttr: VD = _
+ private[this] var _attr: ED = _
+
+ def set(
+ srcId: VertexId, dstId: VertexId,
+ localSrcId: Int, localDstId: Int,
+ srcAttr: VD, dstAttr: VD,
+ attr: ED) {
+ _srcId = srcId
+ _dstId = dstId
+ _localSrcId = localSrcId
+ _localDstId = localDstId
+ _srcAttr = srcAttr
+ _dstAttr = dstAttr
+ _attr = attr
+ }
+
+ def setSrcOnly(srcId: VertexId, localSrcId: Int, srcAttr: VD) {
+ _srcId = srcId
+ _localSrcId = localSrcId
+ _srcAttr = srcAttr
+ }
+
+ def setRest(dstId: VertexId, localDstId: Int, dstAttr: VD, attr: ED) {
+ _dstId = dstId
+ _localDstId = localDstId
+ _dstAttr = dstAttr
+ _attr = attr
+ }
+
+ override def srcId = _srcId
+ override def dstId = _dstId
+ override def srcAttr = _srcAttr
+ override def dstAttr = _dstAttr
+ override def attr = _attr
+
+ override def sendToSrc(msg: A) {
+ send(_localSrcId, msg)
+ }
+ override def sendToDst(msg: A) {
+ send(_localDstId, msg)
+ }
+
+ @inline private def send(localId: Int, msg: A) {
+ if (bitset.get(localId)) {
+ aggregates(localId) = mergeMsg(aggregates(localId), msg)
+ } else {
+ aggregates(localId) = msg
+ bitset.set(localId)
}
}
}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
index 4520beb991515..b0cb0fe47d461 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
@@ -25,10 +25,11 @@ import org.apache.spark.util.collection.{BitSet, OpenHashSet, PrimitiveVector}
import org.apache.spark.graphx._
import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
+/** Constructs an EdgePartition from scratch. */
private[graphx]
class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: ClassTag](
size: Int = 64) {
- var edges = new PrimitiveVector[Edge[ED]](size)
+ private[this] val edges = new PrimitiveVector[Edge[ED]](size)
/** Add a new edge to the partition. */
def add(src: VertexId, dst: VertexId, d: ED) {
@@ -38,19 +39,78 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla
def toEdgePartition: EdgePartition[ED, VD] = {
val edgeArray = edges.trim().array
Sorting.quickSort(edgeArray)(Edge.lexicographicOrdering)
- val srcIds = new Array[VertexId](edgeArray.size)
- val dstIds = new Array[VertexId](edgeArray.size)
+ val localSrcIds = new Array[Int](edgeArray.size)
+ val localDstIds = new Array[Int](edgeArray.size)
+ val data = new Array[ED](edgeArray.size)
+ val index = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int]
+ val global2local = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int]
+ val local2global = new PrimitiveVector[VertexId]
+ var vertexAttrs = Array.empty[VD]
+ // Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and
+ // adding them to the index. Also populate a map from vertex id to a sequential local offset.
+ if (edgeArray.length > 0) {
+ index.update(edgeArray(0).srcId, 0)
+ var currSrcId: VertexId = edgeArray(0).srcId
+ var currLocalId = -1
+ var i = 0
+ while (i < edgeArray.size) {
+ val srcId = edgeArray(i).srcId
+ val dstId = edgeArray(i).dstId
+ localSrcIds(i) = global2local.changeValue(srcId,
+ { currLocalId += 1; local2global += srcId; currLocalId }, identity)
+ localDstIds(i) = global2local.changeValue(dstId,
+ { currLocalId += 1; local2global += dstId; currLocalId }, identity)
+ data(i) = edgeArray(i).attr
+ if (srcId != currSrcId) {
+ currSrcId = srcId
+ index.update(currSrcId, i)
+ }
+
+ i += 1
+ }
+ vertexAttrs = new Array[VD](currLocalId + 1)
+ }
+ new EdgePartition(
+ localSrcIds, localDstIds, data, index, global2local, local2global.trim().array, vertexAttrs,
+ None)
+ }
+}
+
+/**
+ * Constructs an EdgePartition from an existing EdgePartition with the same vertex set. This enables
+ * reuse of the local vertex ids. Intended for internal use in EdgePartition only.
+ */
+private[impl]
+class ExistingEdgePartitionBuilder[
+ @specialized(Long, Int, Double) ED: ClassTag, VD: ClassTag](
+ global2local: GraphXPrimitiveKeyOpenHashMap[VertexId, Int],
+ local2global: Array[VertexId],
+ vertexAttrs: Array[VD],
+ activeSet: Option[VertexSet],
+ size: Int = 64) {
+ private[this] val edges = new PrimitiveVector[EdgeWithLocalIds[ED]](size)
+
+ /** Add a new edge to the partition. */
+ def add(src: VertexId, dst: VertexId, localSrc: Int, localDst: Int, d: ED) {
+ edges += EdgeWithLocalIds(src, dst, localSrc, localDst, d)
+ }
+
+ def toEdgePartition: EdgePartition[ED, VD] = {
+ val edgeArray = edges.trim().array
+ Sorting.quickSort(edgeArray)(EdgeWithLocalIds.lexicographicOrdering)
+ val localSrcIds = new Array[Int](edgeArray.size)
+ val localDstIds = new Array[Int](edgeArray.size)
val data = new Array[ED](edgeArray.size)
val index = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int]
// Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and
// adding them to the index
if (edgeArray.length > 0) {
- index.update(srcIds(0), 0)
- var currSrcId: VertexId = srcIds(0)
+ index.update(edgeArray(0).srcId, 0)
+ var currSrcId: VertexId = edgeArray(0).srcId
var i = 0
while (i < edgeArray.size) {
- srcIds(i) = edgeArray(i).srcId
- dstIds(i) = edgeArray(i).dstId
+ localSrcIds(i) = edgeArray(i).localSrcId
+ localDstIds(i) = edgeArray(i).localDstId
data(i) = edgeArray(i).attr
if (edgeArray(i).srcId != currSrcId) {
currSrcId = edgeArray(i).srcId
@@ -60,13 +120,24 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla
}
}
- // Create and populate a VertexPartition with vids from the edges, but no attributes
- val vidsIter = srcIds.iterator ++ dstIds.iterator
- val vertexIds = new OpenHashSet[VertexId]
- vidsIter.foreach(vid => vertexIds.add(vid))
- val vertices = new VertexPartition(
- vertexIds, new Array[VD](vertexIds.capacity), vertexIds.getBitSet)
+ new EdgePartition(
+ localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs, activeSet)
+ }
+}
- new EdgePartition(srcIds, dstIds, data, index, vertices)
+private[impl] case class EdgeWithLocalIds[@specialized ED](
+ srcId: VertexId, dstId: VertexId, localSrcId: Int, localDstId: Int, attr: ED)
+
+private[impl] object EdgeWithLocalIds {
+ implicit def lexicographicOrdering[ED] = new Ordering[EdgeWithLocalIds[ED]] {
+ override def compare(a: EdgeWithLocalIds[ED], b: EdgeWithLocalIds[ED]): Int = {
+ if (a.srcId == b.srcId) {
+ if (a.dstId == b.dstId) 0
+ else if (a.dstId < b.dstId) -1
+ else 1
+ } else if (a.srcId < b.srcId) -1
+ else 1
+ }
}
+
}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala
new file mode 100644
index 0000000000000..a8169613b4fd2
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.graphx.impl
+
+import scala.reflect.{classTag, ClassTag}
+
+import org.apache.spark.{OneToOneDependency, Partition, Partitioner, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+import org.apache.spark.graphx._
+
+class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] (
+ override val partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])],
+ val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY)
+ extends EdgeRDD[ED](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) {
+
+ override def setName(_name: String): this.type = {
+ if (partitionsRDD.name != null) {
+ partitionsRDD.setName(partitionsRDD.name + ", " + _name)
+ } else {
+ partitionsRDD.setName(_name)
+ }
+ this
+ }
+ setName("EdgeRDD")
+
+ /**
+ * If `partitionsRDD` already has a partitioner, use it. Otherwise assume that the
+ * [[PartitionID]]s in `partitionsRDD` correspond to the actual partitions and create a new
+ * partitioner that allows co-partitioning with `partitionsRDD`.
+ */
+ override val partitioner =
+ partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD)))
+
+ override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect()
+
+ /**
+ * Persists the edge partitions at the specified storage level, ignoring any existing target
+ * storage level.
+ */
+ override def persist(newLevel: StorageLevel): this.type = {
+ partitionsRDD.persist(newLevel)
+ this
+ }
+
+ override def unpersist(blocking: Boolean = true): this.type = {
+ partitionsRDD.unpersist(blocking)
+ this
+ }
+
+ /** Persists the edge partitions using `targetStorageLevel`, which defaults to MEMORY_ONLY. */
+ override def cache(): this.type = {
+ partitionsRDD.persist(targetStorageLevel)
+ this
+ }
+
+ /** The number of edges in the RDD. */
+ override def count(): Long = {
+ partitionsRDD.map(_._2.size.toLong).reduce(_ + _)
+ }
+
+ override def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDDImpl[ED2, VD] =
+ mapEdgePartitions((pid, part) => part.map(f))
+
+ override def reverse: EdgeRDDImpl[ED, VD] = mapEdgePartitions((pid, part) => part.reverse)
+
+ def filter(
+ epred: EdgeTriplet[VD, ED] => Boolean,
+ vpred: (VertexId, VD) => Boolean): EdgeRDDImpl[ED, VD] = {
+ mapEdgePartitions((pid, part) => part.filter(epred, vpred))
+ }
+
+ override def innerJoin[ED2: ClassTag, ED3: ClassTag]
+ (other: EdgeRDD[ED2])
+ (f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDDImpl[ED3, VD] = {
+ val ed2Tag = classTag[ED2]
+ val ed3Tag = classTag[ED3]
+ this.withPartitionsRDD[ED3, VD](partitionsRDD.zipPartitions(other.partitionsRDD, true) {
+ (thisIter, otherIter) =>
+ val (pid, thisEPart) = thisIter.next()
+ val (_, otherEPart) = otherIter.next()
+ Iterator(Tuple2(pid, thisEPart.innerJoin(otherEPart)(f)(ed2Tag, ed3Tag)))
+ })
+ }
+
+ def mapEdgePartitions[ED2: ClassTag, VD2: ClassTag](
+ f: (PartitionID, EdgePartition[ED, VD]) => EdgePartition[ED2, VD2]): EdgeRDDImpl[ED2, VD2] = {
+ this.withPartitionsRDD[ED2, VD2](partitionsRDD.mapPartitions({ iter =>
+ if (iter.hasNext) {
+ val (pid, ep) = iter.next()
+ Iterator(Tuple2(pid, f(pid, ep)))
+ } else {
+ Iterator.empty
+ }
+ }, preservesPartitioning = true))
+ }
+
+ private[graphx] def withPartitionsRDD[ED2: ClassTag, VD2: ClassTag](
+ partitionsRDD: RDD[(PartitionID, EdgePartition[ED2, VD2])]): EdgeRDDImpl[ED2, VD2] = {
+ new EdgeRDDImpl(partitionsRDD, this.targetStorageLevel)
+ }
+
+ override private[graphx] def withTargetStorageLevel(
+ targetStorageLevel: StorageLevel): EdgeRDDImpl[ED, VD] = {
+ new EdgeRDDImpl(this.partitionsRDD, targetStorageLevel)
+ }
+
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala
deleted file mode 100644
index 56f79a7097fce..0000000000000
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala
+++ /dev/null
@@ -1,84 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.graphx.impl
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.graphx._
-import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
-
-/**
- * The Iterator type returned when constructing edge triplets. This could be an anonymous class in
- * EdgePartition.tripletIterator, but we name it here explicitly so it is easier to debug / profile.
- */
-private[impl]
-class EdgeTripletIterator[VD: ClassTag, ED: ClassTag](
- val edgePartition: EdgePartition[ED, VD],
- val includeSrc: Boolean,
- val includeDst: Boolean)
- extends Iterator[EdgeTriplet[VD, ED]] {
-
- // Current position in the array.
- private var pos = 0
-
- override def hasNext: Boolean = pos < edgePartition.size
-
- override def next() = {
- val triplet = new EdgeTriplet[VD, ED]
- triplet.srcId = edgePartition.srcIds(pos)
- if (includeSrc) {
- triplet.srcAttr = edgePartition.vertices(triplet.srcId)
- }
- triplet.dstId = edgePartition.dstIds(pos)
- if (includeDst) {
- triplet.dstAttr = edgePartition.vertices(triplet.dstId)
- }
- triplet.attr = edgePartition.data(pos)
- pos += 1
- triplet
- }
-}
-
-/**
- * An Iterator type for internal use that reuses EdgeTriplet objects. This could be an anonymous
- * class in EdgePartition.upgradeIterator, but we name it here explicitly so it is easier to debug /
- * profile.
- */
-private[impl]
-class ReusingEdgeTripletIterator[VD: ClassTag, ED: ClassTag](
- val edgeIter: Iterator[Edge[ED]],
- val edgePartition: EdgePartition[ED, VD],
- val includeSrc: Boolean,
- val includeDst: Boolean)
- extends Iterator[EdgeTriplet[VD, ED]] {
-
- private val triplet = new EdgeTriplet[VD, ED]
-
- override def hasNext = edgeIter.hasNext
-
- override def next() = {
- triplet.set(edgeIter.next())
- if (includeSrc) {
- triplet.srcAttr = edgePartition.vertices(triplet.srcId)
- }
- if (includeDst) {
- triplet.dstAttr = edgePartition.vertices(triplet.dstId)
- }
- triplet
- }
-}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
index 33f35cfb69a26..0eae2a673874a 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
@@ -23,7 +23,6 @@ import org.apache.spark.HashPartitioner
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.storage.StorageLevel
-
import org.apache.spark.graphx._
import org.apache.spark.graphx.impl.GraphImpl._
import org.apache.spark.graphx.util.BytecodeUtils
@@ -44,7 +43,7 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
/** Default constructor is provided to support serialization */
protected def this() = this(null, null)
- @transient override val edges: EdgeRDD[ED, VD] = replicatedVertexView.edges
+ @transient override val edges: EdgeRDDImpl[ED, VD] = replicatedVertexView.edges
/** Return a RDD that brings edges together with their source and destination vertices. */
@transient override lazy val triplets: RDD[EdgeTriplet[VD, ED]] = {
@@ -127,13 +126,12 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
}
override def mapTriplets[ED2: ClassTag](
- f: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2]): Graph[VD, ED2] = {
+ f: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2],
+ tripletFields: TripletFields): Graph[VD, ED2] = {
vertices.cache()
- val mapUsesSrcAttr = accessesVertexAttr(f, "srcAttr")
- val mapUsesDstAttr = accessesVertexAttr(f, "dstAttr")
- replicatedVertexView.upgrade(vertices, mapUsesSrcAttr, mapUsesDstAttr)
+ replicatedVertexView.upgrade(vertices, tripletFields.useSrc, tripletFields.useDst)
val newEdges = replicatedVertexView.edges.mapEdgePartitions { (pid, part) =>
- part.map(f(pid, part.tripletIterator(mapUsesSrcAttr, mapUsesDstAttr)))
+ part.map(f(pid, part.tripletIterator(tripletFields.useSrc, tripletFields.useDst)))
}
new GraphImpl(vertices, replicatedVertexView.withEdges(newEdges))
}
@@ -171,15 +169,38 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
override def mapReduceTriplets[A: ClassTag](
mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
reduceFunc: (A, A) => A,
- activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None): VertexRDD[A] = {
+ activeSetOpt: Option[(VertexRDD[_], EdgeDirection)]): VertexRDD[A] = {
+
+ def sendMsg(ctx: EdgeContext[VD, ED, A]) {
+ mapFunc(ctx.toEdgeTriplet).foreach { kv =>
+ val id = kv._1
+ val msg = kv._2
+ if (id == ctx.srcId) {
+ ctx.sendToSrc(msg)
+ } else {
+ assert(id == ctx.dstId)
+ ctx.sendToDst(msg)
+ }
+ }
+ }
- vertices.cache()
+ val mapUsesSrcAttr = accessesVertexAttr(mapFunc, "srcAttr")
+ val mapUsesDstAttr = accessesVertexAttr(mapFunc, "dstAttr")
+ val tripletFields = new TripletFields(mapUsesSrcAttr, mapUsesDstAttr, true)
+
+ aggregateMessagesWithActiveSet(sendMsg, reduceFunc, tripletFields, activeSetOpt)
+ }
+
+ override def aggregateMessagesWithActiveSet[A: ClassTag](
+ sendMsg: EdgeContext[VD, ED, A] => Unit,
+ mergeMsg: (A, A) => A,
+ tripletFields: TripletFields,
+ activeSetOpt: Option[(VertexRDD[_], EdgeDirection)]): VertexRDD[A] = {
+ vertices.cache()
// For each vertex, replicate its attribute only to partitions where it is
// in the relevant position in an edge.
- val mapUsesSrcAttr = accessesVertexAttr(mapFunc, "srcAttr")
- val mapUsesDstAttr = accessesVertexAttr(mapFunc, "dstAttr")
- replicatedVertexView.upgrade(vertices, mapUsesSrcAttr, mapUsesDstAttr)
+ replicatedVertexView.upgrade(vertices, tripletFields.useSrc, tripletFields.useDst)
val view = activeSetOpt match {
case Some((activeSet, _)) =>
replicatedVertexView.withActiveSet(activeSet)
@@ -193,42 +214,40 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
case (pid, edgePartition) =>
// Choose scan method
val activeFraction = edgePartition.numActives.getOrElse(0) / edgePartition.indexSize.toFloat
- val edgeIter = activeDirectionOpt match {
+ activeDirectionOpt match {
case Some(EdgeDirection.Both) =>
if (activeFraction < 0.8) {
- edgePartition.indexIterator(srcVertexId => edgePartition.isActive(srcVertexId))
- .filter(e => edgePartition.isActive(e.dstId))
+ edgePartition.aggregateMessagesIndexScan(sendMsg, mergeMsg, tripletFields,
+ EdgeActiveness.Both)
} else {
- edgePartition.iterator.filter(e =>
- edgePartition.isActive(e.srcId) && edgePartition.isActive(e.dstId))
+ edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
+ EdgeActiveness.Both)
}
case Some(EdgeDirection.Either) =>
// TODO: Because we only have a clustered index on the source vertex ID, we can't filter
// the index here. Instead we have to scan all edges and then do the filter.
- edgePartition.iterator.filter(e =>
- edgePartition.isActive(e.srcId) || edgePartition.isActive(e.dstId))
+ edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
+ EdgeActiveness.Either)
case Some(EdgeDirection.Out) =>
if (activeFraction < 0.8) {
- edgePartition.indexIterator(srcVertexId => edgePartition.isActive(srcVertexId))
+ edgePartition.aggregateMessagesIndexScan(sendMsg, mergeMsg, tripletFields,
+ EdgeActiveness.SrcOnly)
} else {
- edgePartition.iterator.filter(e => edgePartition.isActive(e.srcId))
+ edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
+ EdgeActiveness.SrcOnly)
}
case Some(EdgeDirection.In) =>
- edgePartition.iterator.filter(e => edgePartition.isActive(e.dstId))
+ edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
+ EdgeActiveness.DstOnly)
case _ => // None
- edgePartition.iterator
+ edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
+ EdgeActiveness.Neither)
}
-
- // Scan edges and run the map function
- val mapOutputs = edgePartition.upgradeIterator(edgeIter, mapUsesSrcAttr, mapUsesDstAttr)
- .flatMap(mapFunc(_))
- // Note: This doesn't allow users to send messages to arbitrary vertices.
- edgePartition.vertices.aggregateUsingIndex(mapOutputs, reduceFunc).iterator
- }).setName("GraphImpl.mapReduceTriplets - preAgg")
+ }).setName("GraphImpl.aggregateMessages - preAgg")
// do the final reduction reusing the index map
- vertices.aggregateUsingIndex(preAgg, reduceFunc)
- } // end of mapReduceTriplets
+ vertices.aggregateUsingIndex(preAgg, mergeMsg)
+ }
override def outerJoinVertices[U: ClassTag, VD2: ClassTag]
(other: RDD[(VertexId, U)])
@@ -304,11 +323,10 @@ object GraphImpl {
*/
def apply[VD: ClassTag, ED: ClassTag](
vertices: VertexRDD[VD],
- edges: EdgeRDD[ED, _]): GraphImpl[VD, ED] = {
+ edges: EdgeRDD[ED]): GraphImpl[VD, ED] = {
// Convert the vertex partitions in edges to the correct type
- val newEdges = edges.mapEdgePartitions(
- (pid, part) => part.withVertices(part.vertices.map(
- (vid, attr) => null.asInstanceOf[VD])))
+ val newEdges = edges.asInstanceOf[EdgeRDDImpl[ED, _]]
+ .mapEdgePartitions((pid, part) => part.withoutVertexAttributes[VD])
GraphImpl.fromExistingRDDs(vertices, newEdges)
}
@@ -319,8 +337,8 @@ object GraphImpl {
*/
def fromExistingRDDs[VD: ClassTag, ED: ClassTag](
vertices: VertexRDD[VD],
- edges: EdgeRDD[ED, VD]): GraphImpl[VD, ED] = {
- new GraphImpl(vertices, new ReplicatedVertexView(edges))
+ edges: EdgeRDD[ED]): GraphImpl[VD, ED] = {
+ new GraphImpl(vertices, new ReplicatedVertexView(edges.asInstanceOf[EdgeRDDImpl[ED, VD]]))
}
/**
@@ -328,7 +346,7 @@ object GraphImpl {
* `defaultVertexAttr`. The vertices will have the same number of partitions as the EdgeRDD.
*/
private def fromEdgeRDD[VD: ClassTag, ED: ClassTag](
- edges: EdgeRDD[ED, VD],
+ edges: EdgeRDDImpl[ED, VD],
defaultVertexAttr: VD,
edgeStorageLevel: StorageLevel,
vertexStorageLevel: StorageLevel): GraphImpl[VD, ED] = {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
deleted file mode 100644
index 714f3b81c9dad..0000000000000
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.graphx.impl
-
-import scala.language.implicitConversions
-import scala.reflect.{classTag, ClassTag}
-
-import org.apache.spark.Partitioner
-import org.apache.spark.graphx.{PartitionID, VertexId}
-import org.apache.spark.rdd.{ShuffledRDD, RDD}
-
-
-private[graphx]
-class VertexRDDFunctions[VD: ClassTag](self: RDD[(VertexId, VD)]) {
- def copartitionWithVertices(partitioner: Partitioner): RDD[(VertexId, VD)] = {
- val rdd = new ShuffledRDD[VertexId, VD, VD](self, partitioner)
-
- // Set a custom serializer if the data is of int or double type.
- if (classTag[VD] == ClassTag.Int) {
- rdd.setSerializer(new IntAggMsgSerializer)
- } else if (classTag[VD] == ClassTag.Long) {
- rdd.setSerializer(new LongAggMsgSerializer)
- } else if (classTag[VD] == ClassTag.Double) {
- rdd.setSerializer(new DoubleAggMsgSerializer)
- }
- rdd
- }
-}
-
-private[graphx]
-object VertexRDDFunctions {
- implicit def rdd2VertexRDDFunctions[VD: ClassTag](rdd: RDD[(VertexId, VD)]) = {
- new VertexRDDFunctions(rdd)
- }
-}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala
index 86b366eb9202b..8ab255bd4038c 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala
@@ -33,7 +33,7 @@ import org.apache.spark.graphx._
*/
private[impl]
class ReplicatedVertexView[VD: ClassTag, ED: ClassTag](
- var edges: EdgeRDD[ED, VD],
+ var edges: EdgeRDDImpl[ED, VD],
var hasSrcId: Boolean = false,
var hasDstId: Boolean = false) {
@@ -42,7 +42,7 @@ class ReplicatedVertexView[VD: ClassTag, ED: ClassTag](
* shipping level.
*/
def withEdges[VD2: ClassTag, ED2: ClassTag](
- edges_ : EdgeRDD[ED2, VD2]): ReplicatedVertexView[VD2, ED2] = {
+ edges_ : EdgeRDDImpl[ED2, VD2]): ReplicatedVertexView[VD2, ED2] = {
new ReplicatedVertexView(edges_, hasSrcId, hasDstId)
}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
index b27485953f719..eb3c997e0f3c0 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
@@ -29,24 +29,6 @@ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage
-private[graphx]
-class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) {
- /** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. */
- def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = {
- new ShuffledRDD[VertexId, Int, Int](
- self, partitioner).setSerializer(new RoutingTableMessageSerializer)
- }
-}
-
-private[graphx]
-object RoutingTableMessageRDDFunctions {
- import scala.language.implicitConversions
-
- implicit def rdd2RoutingTableMessageRDDFunctions(rdd: RDD[RoutingTableMessage]) = {
- new RoutingTableMessageRDDFunctions(rdd)
- }
-}
-
private[graphx]
object RoutingTablePartition {
/**
@@ -74,11 +56,9 @@ object RoutingTablePartition {
// Determine which positions each vertex id appears in using a map where the low 2 bits
// represent src and dst
val map = new GraphXPrimitiveKeyOpenHashMap[VertexId, Byte]
- edgePartition.srcIds.iterator.foreach { srcId =>
- map.changeValue(srcId, 0x1, (b: Byte) => (b | 0x1).toByte)
- }
- edgePartition.dstIds.iterator.foreach { dstId =>
- map.changeValue(dstId, 0x2, (b: Byte) => (b | 0x2).toByte)
+ edgePartition.iterator.foreach { e =>
+ map.changeValue(e.srcId, 0x1, (b: Byte) => (b | 0x1).toByte)
+ map.changeValue(e.dstId, 0x2, (b: Byte) => (b | 0x2).toByte)
}
map.iterator.map { vidAndPosition =>
val vid = vidAndPosition._1
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
deleted file mode 100644
index 3909efcdfc993..0000000000000
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
+++ /dev/null
@@ -1,369 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.graphx.impl
-
-import scala.language.existentials
-
-import java.io.{EOFException, InputStream, OutputStream}
-import java.nio.ByteBuffer
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.serializer._
-
-import org.apache.spark.graphx._
-import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage
-
-private[graphx]
-class RoutingTableMessageSerializer extends Serializer with Serializable {
- override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
-
- override def serializeStream(s: OutputStream): SerializationStream =
- new ShuffleSerializationStream(s) {
- def writeObject[T: ClassTag](t: T): SerializationStream = {
- val msg = t.asInstanceOf[RoutingTableMessage]
- writeVarLong(msg._1, optimizePositive = false)
- writeInt(msg._2)
- this
- }
- }
-
- override def deserializeStream(s: InputStream): DeserializationStream =
- new ShuffleDeserializationStream(s) {
- override def readObject[T: ClassTag](): T = {
- val a = readVarLong(optimizePositive = false)
- val b = readInt()
- (a, b).asInstanceOf[T]
- }
- }
- }
-}
-
-private[graphx]
-class VertexIdMsgSerializer extends Serializer with Serializable {
- override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
-
- override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
- def writeObject[T: ClassTag](t: T) = {
- val msg = t.asInstanceOf[(VertexId, _)]
- writeVarLong(msg._1, optimizePositive = false)
- this
- }
- }
-
- override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
- override def readObject[T: ClassTag](): T = {
- (readVarLong(optimizePositive = false), null).asInstanceOf[T]
- }
- }
- }
-}
-
-/** A special shuffle serializer for AggregationMessage[Int]. */
-private[graphx]
-class IntAggMsgSerializer extends Serializer with Serializable {
- override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
-
- override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
- def writeObject[T: ClassTag](t: T) = {
- val msg = t.asInstanceOf[(VertexId, Int)]
- writeVarLong(msg._1, optimizePositive = false)
- writeUnsignedVarInt(msg._2)
- this
- }
- }
-
- override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
- override def readObject[T: ClassTag](): T = {
- val a = readVarLong(optimizePositive = false)
- val b = readUnsignedVarInt()
- (a, b).asInstanceOf[T]
- }
- }
- }
-}
-
-/** A special shuffle serializer for AggregationMessage[Long]. */
-private[graphx]
-class LongAggMsgSerializer extends Serializer with Serializable {
- override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
-
- override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
- def writeObject[T: ClassTag](t: T) = {
- val msg = t.asInstanceOf[(VertexId, Long)]
- writeVarLong(msg._1, optimizePositive = false)
- writeVarLong(msg._2, optimizePositive = true)
- this
- }
- }
-
- override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
- override def readObject[T: ClassTag](): T = {
- val a = readVarLong(optimizePositive = false)
- val b = readVarLong(optimizePositive = true)
- (a, b).asInstanceOf[T]
- }
- }
- }
-}
-
-/** A special shuffle serializer for AggregationMessage[Double]. */
-private[graphx]
-class DoubleAggMsgSerializer extends Serializer with Serializable {
- override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
-
- override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
- def writeObject[T: ClassTag](t: T) = {
- val msg = t.asInstanceOf[(VertexId, Double)]
- writeVarLong(msg._1, optimizePositive = false)
- writeDouble(msg._2)
- this
- }
- }
-
- override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
- def readObject[T: ClassTag](): T = {
- val a = readVarLong(optimizePositive = false)
- val b = readDouble()
- (a, b).asInstanceOf[T]
- }
- }
- }
-}
-
-////////////////////////////////////////////////////////////////////////////////
-// Helper classes to shorten the implementation of those special serializers.
-////////////////////////////////////////////////////////////////////////////////
-
-private[graphx]
-abstract class ShuffleSerializationStream(s: OutputStream) extends SerializationStream {
- // The implementation should override this one.
- def writeObject[T: ClassTag](t: T): SerializationStream
-
- def writeInt(v: Int) {
- s.write(v >> 24)
- s.write(v >> 16)
- s.write(v >> 8)
- s.write(v)
- }
-
- def writeUnsignedVarInt(value: Int) {
- if ((value >>> 7) == 0) {
- s.write(value.toInt)
- } else if ((value >>> 14) == 0) {
- s.write((value & 0x7F) | 0x80)
- s.write(value >>> 7)
- } else if ((value >>> 21) == 0) {
- s.write((value & 0x7F) | 0x80)
- s.write(value >>> 7 | 0x80)
- s.write(value >>> 14)
- } else if ((value >>> 28) == 0) {
- s.write((value & 0x7F) | 0x80)
- s.write(value >>> 7 | 0x80)
- s.write(value >>> 14 | 0x80)
- s.write(value >>> 21)
- } else {
- s.write((value & 0x7F) | 0x80)
- s.write(value >>> 7 | 0x80)
- s.write(value >>> 14 | 0x80)
- s.write(value >>> 21 | 0x80)
- s.write(value >>> 28)
- }
- }
-
- def writeVarLong(value: Long, optimizePositive: Boolean) {
- val v = if (!optimizePositive) (value << 1) ^ (value >> 63) else value
- if ((v >>> 7) == 0) {
- s.write(v.toInt)
- } else if ((v >>> 14) == 0) {
- s.write(((v & 0x7F) | 0x80).toInt)
- s.write((v >>> 7).toInt)
- } else if ((v >>> 21) == 0) {
- s.write(((v & 0x7F) | 0x80).toInt)
- s.write((v >>> 7 | 0x80).toInt)
- s.write((v >>> 14).toInt)
- } else if ((v >>> 28) == 0) {
- s.write(((v & 0x7F) | 0x80).toInt)
- s.write((v >>> 7 | 0x80).toInt)
- s.write((v >>> 14 | 0x80).toInt)
- s.write((v >>> 21).toInt)
- } else if ((v >>> 35) == 0) {
- s.write(((v & 0x7F) | 0x80).toInt)
- s.write((v >>> 7 | 0x80).toInt)
- s.write((v >>> 14 | 0x80).toInt)
- s.write((v >>> 21 | 0x80).toInt)
- s.write((v >>> 28).toInt)
- } else if ((v >>> 42) == 0) {
- s.write(((v & 0x7F) | 0x80).toInt)
- s.write((v >>> 7 | 0x80).toInt)
- s.write((v >>> 14 | 0x80).toInt)
- s.write((v >>> 21 | 0x80).toInt)
- s.write((v >>> 28 | 0x80).toInt)
- s.write((v >>> 35).toInt)
- } else if ((v >>> 49) == 0) {
- s.write(((v & 0x7F) | 0x80).toInt)
- s.write((v >>> 7 | 0x80).toInt)
- s.write((v >>> 14 | 0x80).toInt)
- s.write((v >>> 21 | 0x80).toInt)
- s.write((v >>> 28 | 0x80).toInt)
- s.write((v >>> 35 | 0x80).toInt)
- s.write((v >>> 42).toInt)
- } else if ((v >>> 56) == 0) {
- s.write(((v & 0x7F) | 0x80).toInt)
- s.write((v >>> 7 | 0x80).toInt)
- s.write((v >>> 14 | 0x80).toInt)
- s.write((v >>> 21 | 0x80).toInt)
- s.write((v >>> 28 | 0x80).toInt)
- s.write((v >>> 35 | 0x80).toInt)
- s.write((v >>> 42 | 0x80).toInt)
- s.write((v >>> 49).toInt)
- } else {
- s.write(((v & 0x7F) | 0x80).toInt)
- s.write((v >>> 7 | 0x80).toInt)
- s.write((v >>> 14 | 0x80).toInt)
- s.write((v >>> 21 | 0x80).toInt)
- s.write((v >>> 28 | 0x80).toInt)
- s.write((v >>> 35 | 0x80).toInt)
- s.write((v >>> 42 | 0x80).toInt)
- s.write((v >>> 49 | 0x80).toInt)
- s.write((v >>> 56).toInt)
- }
- }
-
- def writeLong(v: Long) {
- s.write((v >>> 56).toInt)
- s.write((v >>> 48).toInt)
- s.write((v >>> 40).toInt)
- s.write((v >>> 32).toInt)
- s.write((v >>> 24).toInt)
- s.write((v >>> 16).toInt)
- s.write((v >>> 8).toInt)
- s.write(v.toInt)
- }
-
- def writeDouble(v: Double): Unit = writeLong(java.lang.Double.doubleToLongBits(v))
-
- override def flush(): Unit = s.flush()
-
- override def close(): Unit = s.close()
-}
-
-private[graphx]
-abstract class ShuffleDeserializationStream(s: InputStream) extends DeserializationStream {
- // The implementation should override this one.
- def readObject[T: ClassTag](): T
-
- def readInt(): Int = {
- val first = s.read()
- if (first < 0) throw new EOFException
- (first & 0xFF) << 24 | (s.read() & 0xFF) << 16 | (s.read() & 0xFF) << 8 | (s.read() & 0xFF)
- }
-
- def readUnsignedVarInt(): Int = {
- var value: Int = 0
- var i: Int = 0
- def readOrThrow(): Int = {
- val in = s.read()
- if (in < 0) throw new EOFException
- in & 0xFF
- }
- var b: Int = readOrThrow()
- while ((b & 0x80) != 0) {
- value |= (b & 0x7F) << i
- i += 7
- if (i > 35) throw new IllegalArgumentException("Variable length quantity is too long")
- b = readOrThrow()
- }
- value | (b << i)
- }
-
- def readVarLong(optimizePositive: Boolean): Long = {
- def readOrThrow(): Int = {
- val in = s.read()
- if (in < 0) throw new EOFException
- in & 0xFF
- }
- var b = readOrThrow()
- var ret: Long = b & 0x7F
- if ((b & 0x80) != 0) {
- b = readOrThrow()
- ret |= (b & 0x7F) << 7
- if ((b & 0x80) != 0) {
- b = readOrThrow()
- ret |= (b & 0x7F) << 14
- if ((b & 0x80) != 0) {
- b = readOrThrow()
- ret |= (b & 0x7F) << 21
- if ((b & 0x80) != 0) {
- b = readOrThrow()
- ret |= (b & 0x7F).toLong << 28
- if ((b & 0x80) != 0) {
- b = readOrThrow()
- ret |= (b & 0x7F).toLong << 35
- if ((b & 0x80) != 0) {
- b = readOrThrow()
- ret |= (b & 0x7F).toLong << 42
- if ((b & 0x80) != 0) {
- b = readOrThrow()
- ret |= (b & 0x7F).toLong << 49
- if ((b & 0x80) != 0) {
- b = readOrThrow()
- ret |= b.toLong << 56
- }
- }
- }
- }
- }
- }
- }
- }
- if (!optimizePositive) (ret >>> 1) ^ -(ret & 1) else ret
- }
-
- def readLong(): Long = {
- val first = s.read()
- if (first < 0) throw new EOFException()
- (first.toLong << 56) |
- (s.read() & 0xFF).toLong << 48 |
- (s.read() & 0xFF).toLong << 40 |
- (s.read() & 0xFF).toLong << 32 |
- (s.read() & 0xFF).toLong << 24 |
- (s.read() & 0xFF) << 16 |
- (s.read() & 0xFF) << 8 |
- (s.read() & 0xFF)
- }
-
- def readDouble(): Double = java.lang.Double.longBitsToDouble(readLong())
-
- override def close(): Unit = s.close()
-}
-
-private[graphx] sealed trait ShuffleSerializerInstance extends SerializerInstance {
-
- override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException
-
- override def deserialize[T: ClassTag](bytes: ByteBuffer): T =
- throw new UnsupportedOperationException
-
- override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
- throw new UnsupportedOperationException
-
- // The implementation should override the following two.
- override def serializeStream(s: OutputStream): SerializationStream
- override def deserializeStream(s: InputStream): DeserializationStream
-}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala
new file mode 100644
index 0000000000000..d92a55a189298
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.graphx.impl
+
+import scala.reflect.ClassTag
+
+import org.apache.spark._
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd._
+import org.apache.spark.storage.StorageLevel
+
+import org.apache.spark.graphx._
+
+class VertexRDDImpl[VD] private[graphx] (
+ val partitionsRDD: RDD[ShippableVertexPartition[VD]],
+ val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY)
+ (implicit override protected val vdTag: ClassTag[VD])
+ extends VertexRDD[VD](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) {
+
+ require(partitionsRDD.partitioner.isDefined)
+
+ override def reindex(): VertexRDD[VD] = this.withPartitionsRDD(partitionsRDD.map(_.reindex()))
+
+ override val partitioner = partitionsRDD.partitioner
+
+ override protected def getPreferredLocations(s: Partition): Seq[String] =
+ partitionsRDD.preferredLocations(s)
+
+ override def setName(_name: String): this.type = {
+ if (partitionsRDD.name != null) {
+ partitionsRDD.setName(partitionsRDD.name + ", " + _name)
+ } else {
+ partitionsRDD.setName(_name)
+ }
+ this
+ }
+ setName("VertexRDD")
+
+ /**
+ * Persists the vertex partitions at the specified storage level, ignoring any existing target
+ * storage level.
+ */
+ override def persist(newLevel: StorageLevel): this.type = {
+ partitionsRDD.persist(newLevel)
+ this
+ }
+
+ override def unpersist(blocking: Boolean = true): this.type = {
+ partitionsRDD.unpersist(blocking)
+ this
+ }
+
+ /** Persists the vertex partitions at `targetStorageLevel`, which defaults to MEMORY_ONLY. */
+ override def cache(): this.type = {
+ partitionsRDD.persist(targetStorageLevel)
+ this
+ }
+
+ /** The number of vertices in the RDD. */
+ override def count(): Long = {
+ partitionsRDD.map(_.size).reduce(_ + _)
+ }
+
+ override private[graphx] def mapVertexPartitions[VD2: ClassTag](
+ f: ShippableVertexPartition[VD] => ShippableVertexPartition[VD2])
+ : VertexRDD[VD2] = {
+ val newPartitionsRDD = partitionsRDD.mapPartitions(_.map(f), preservesPartitioning = true)
+ this.withPartitionsRDD(newPartitionsRDD)
+ }
+
+ override def mapValues[VD2: ClassTag](f: VD => VD2): VertexRDD[VD2] =
+ this.mapVertexPartitions(_.map((vid, attr) => f(attr)))
+
+ override def mapValues[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexRDD[VD2] =
+ this.mapVertexPartitions(_.map(f))
+
+ override def diff(other: VertexRDD[VD]): VertexRDD[VD] = {
+ val newPartitionsRDD = partitionsRDD.zipPartitions(
+ other.partitionsRDD, preservesPartitioning = true
+ ) { (thisIter, otherIter) =>
+ val thisPart = thisIter.next()
+ val otherPart = otherIter.next()
+ Iterator(thisPart.diff(otherPart))
+ }
+ this.withPartitionsRDD(newPartitionsRDD)
+ }
+
+ override def leftZipJoin[VD2: ClassTag, VD3: ClassTag]
+ (other: VertexRDD[VD2])(f: (VertexId, VD, Option[VD2]) => VD3): VertexRDD[VD3] = {
+ val newPartitionsRDD = partitionsRDD.zipPartitions(
+ other.partitionsRDD, preservesPartitioning = true
+ ) { (thisIter, otherIter) =>
+ val thisPart = thisIter.next()
+ val otherPart = otherIter.next()
+ Iterator(thisPart.leftJoin(otherPart)(f))
+ }
+ this.withPartitionsRDD(newPartitionsRDD)
+ }
+
+ override def leftJoin[VD2: ClassTag, VD3: ClassTag]
+ (other: RDD[(VertexId, VD2)])
+ (f: (VertexId, VD, Option[VD2]) => VD3)
+ : VertexRDD[VD3] = {
+ // Test if the other vertex is a VertexRDD to choose the optimal join strategy.
+ // If the other set is a VertexRDD then we use the much more efficient leftZipJoin
+ other match {
+ case other: VertexRDD[_] =>
+ leftZipJoin(other)(f)
+ case _ =>
+ this.withPartitionsRDD[VD3](
+ partitionsRDD.zipPartitions(
+ other.partitionBy(this.partitioner.get), preservesPartitioning = true) {
+ (partIter, msgs) => partIter.map(_.leftJoin(msgs)(f))
+ }
+ )
+ }
+ }
+
+ override def innerZipJoin[U: ClassTag, VD2: ClassTag](other: VertexRDD[U])
+ (f: (VertexId, VD, U) => VD2): VertexRDD[VD2] = {
+ val newPartitionsRDD = partitionsRDD.zipPartitions(
+ other.partitionsRDD, preservesPartitioning = true
+ ) { (thisIter, otherIter) =>
+ val thisPart = thisIter.next()
+ val otherPart = otherIter.next()
+ Iterator(thisPart.innerJoin(otherPart)(f))
+ }
+ this.withPartitionsRDD(newPartitionsRDD)
+ }
+
+ override def innerJoin[U: ClassTag, VD2: ClassTag](other: RDD[(VertexId, U)])
+ (f: (VertexId, VD, U) => VD2): VertexRDD[VD2] = {
+ // Test if the other vertex is a VertexRDD to choose the optimal join strategy.
+ // If the other set is a VertexRDD then we use the much more efficient innerZipJoin
+ other match {
+ case other: VertexRDD[_] =>
+ innerZipJoin(other)(f)
+ case _ =>
+ this.withPartitionsRDD(
+ partitionsRDD.zipPartitions(
+ other.partitionBy(this.partitioner.get), preservesPartitioning = true) {
+ (partIter, msgs) => partIter.map(_.innerJoin(msgs)(f))
+ }
+ )
+ }
+ }
+
+ override def aggregateUsingIndex[VD2: ClassTag](
+ messages: RDD[(VertexId, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2] = {
+ val shuffled = messages.partitionBy(this.partitioner.get)
+ val parts = partitionsRDD.zipPartitions(shuffled, true) { (thisIter, msgIter) =>
+ thisIter.map(_.aggregateUsingIndex(msgIter, reduceFunc))
+ }
+ this.withPartitionsRDD[VD2](parts)
+ }
+
+ override def reverseRoutingTables(): VertexRDD[VD] =
+ this.mapVertexPartitions(vPart => vPart.withRoutingTable(vPart.routingTable.reverse))
+
+ override def withEdges(edges: EdgeRDD[_]): VertexRDD[VD] = {
+ val routingTables = VertexRDD.createRoutingTables(edges, this.partitioner.get)
+ val vertexPartitions = partitionsRDD.zipPartitions(routingTables, true) {
+ (partIter, routingTableIter) =>
+ val routingTable =
+ if (routingTableIter.hasNext) routingTableIter.next() else RoutingTablePartition.empty
+ partIter.map(_.withRoutingTable(routingTable))
+ }
+ this.withPartitionsRDD(vertexPartitions)
+ }
+
+ override private[graphx] def withPartitionsRDD[VD2: ClassTag](
+ partitionsRDD: RDD[ShippableVertexPartition[VD2]]): VertexRDD[VD2] = {
+ new VertexRDDImpl(partitionsRDD, this.targetStorageLevel)
+ }
+
+ override private[graphx] def withTargetStorageLevel(
+ targetStorageLevel: StorageLevel): VertexRDD[VD] = {
+ new VertexRDDImpl(this.partitionsRDD, targetStorageLevel)
+ }
+
+ override private[graphx] def shipVertexAttributes(
+ shipSrc: Boolean, shipDst: Boolean): RDD[(PartitionID, VertexAttributeBlock[VD])] = {
+ partitionsRDD.mapPartitions(_.flatMap(_.shipVertexAttributes(shipSrc, shipDst)))
+ }
+
+ override private[graphx] def shipVertexIds(): RDD[(PartitionID, Array[VertexId])] = {
+ partitionsRDD.mapPartitions(_.flatMap(_.shipVertexIds()))
+ }
+
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
index 257e2f3a36115..e139959c3f5c1 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
@@ -85,7 +85,7 @@ object PageRank extends Logging {
// Associate the degree with each vertex
.outerJoinVertices(graph.outDegrees) { (vid, vdata, deg) => deg.getOrElse(0) }
// Set the weight on the edges based on the degree
- .mapTriplets( e => 1.0 / e.srcAttr )
+ .mapTriplets( e => 1.0 / e.srcAttr, TripletFields.Src )
// Set the vertex attributes to the initial pagerank values
.mapVertices( (id, attr) => resetProb )
@@ -96,8 +96,8 @@ object PageRank extends Logging {
// Compute the outgoing rank contributions of each vertex, perform local preaggregation, and
// do the final aggregation at the receiving vertices. Requires a shuffle for aggregation.
- val rankUpdates = rankGraph.mapReduceTriplets[Double](
- e => Iterator((e.dstId, e.srcAttr * e.attr)), _ + _)
+ val rankUpdates = rankGraph.aggregateMessages[Double](
+ ctx => ctx.sendToDst(ctx.srcAttr * ctx.attr), _ + _, TripletFields.Src)
// Apply the final rank updates to get the new ranks, using join to preserve ranks of vertices
// that didn't receive a message. Requires a shuffle for broadcasting updated ranks to the
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
index ccd7de537b6e3..f58587e10a820 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
@@ -74,9 +74,9 @@ object SVDPlusPlus {
var g = Graph.fromEdges(edges, defaultF(conf.rank)).cache()
// Calculate initial bias and norm
- val t0 = g.mapReduceTriplets(
- et => Iterator((et.srcId, (1L, et.attr)), (et.dstId, (1L, et.attr))),
- (g1: (Long, Double), g2: (Long, Double)) => (g1._1 + g2._1, g1._2 + g2._2))
+ val t0 = g.aggregateMessages[(Long, Double)](
+ ctx => { ctx.sendToSrc((1L, ctx.attr)); ctx.sendToDst((1L, ctx.attr)) },
+ (g1, g2) => (g1._1 + g2._1, g1._2 + g2._2))
g = g.outerJoinVertices(t0) {
(vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double),
@@ -84,15 +84,17 @@ object SVDPlusPlus {
(vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1))
}
- def mapTrainF(conf: Conf, u: Double)
- (et: EdgeTriplet[(DoubleMatrix, DoubleMatrix, Double, Double), Double])
- : Iterator[(VertexId, (DoubleMatrix, DoubleMatrix, Double))] = {
- val (usr, itm) = (et.srcAttr, et.dstAttr)
+ def sendMsgTrainF(conf: Conf, u: Double)
+ (ctx: EdgeContext[
+ (DoubleMatrix, DoubleMatrix, Double, Double),
+ Double,
+ (DoubleMatrix, DoubleMatrix, Double)]) {
+ val (usr, itm) = (ctx.srcAttr, ctx.dstAttr)
val (p, q) = (usr._1, itm._1)
var pred = u + usr._3 + itm._3 + q.dot(usr._2)
pred = math.max(pred, conf.minVal)
pred = math.min(pred, conf.maxVal)
- val err = et.attr - pred
+ val err = ctx.attr - pred
val updateP = q.mul(err)
.subColumnVector(p.mul(conf.gamma7))
.mul(conf.gamma2)
@@ -102,16 +104,16 @@ object SVDPlusPlus {
val updateY = q.mul(err * usr._4)
.subColumnVector(itm._2.mul(conf.gamma7))
.mul(conf.gamma2)
- Iterator((et.srcId, (updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1)),
- (et.dstId, (updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1)))
+ ctx.sendToSrc((updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1))
+ ctx.sendToDst((updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1))
}
for (i <- 0 until conf.maxIters) {
// Phase 1, calculate pu + |N(u)|^(-0.5)*sum(y) for user nodes
g.cache()
- val t1 = g.mapReduceTriplets(
- et => Iterator((et.srcId, et.dstAttr._2)),
- (g1: DoubleMatrix, g2: DoubleMatrix) => g1.addColumnVector(g2))
+ val t1 = g.aggregateMessages[DoubleMatrix](
+ ctx => ctx.sendToSrc(ctx.dstAttr._2),
+ (g1, g2) => g1.addColumnVector(g2))
g = g.outerJoinVertices(t1) {
(vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double),
msg: Option[DoubleMatrix]) =>
@@ -121,8 +123,8 @@ object SVDPlusPlus {
// Phase 2, update p for user nodes and q, y for item nodes
g.cache()
- val t2 = g.mapReduceTriplets(
- mapTrainF(conf, u),
+ val t2 = g.aggregateMessages(
+ sendMsgTrainF(conf, u),
(g1: (DoubleMatrix, DoubleMatrix, Double), g2: (DoubleMatrix, DoubleMatrix, Double)) =>
(g1._1.addColumnVector(g2._1), g1._2.addColumnVector(g2._2), g1._3 + g2._3))
g = g.outerJoinVertices(t2) {
@@ -135,20 +137,18 @@ object SVDPlusPlus {
}
// calculate error on training set
- def mapTestF(conf: Conf, u: Double)
- (et: EdgeTriplet[(DoubleMatrix, DoubleMatrix, Double, Double), Double])
- : Iterator[(VertexId, Double)] =
- {
- val (usr, itm) = (et.srcAttr, et.dstAttr)
+ def sendMsgTestF(conf: Conf, u: Double)
+ (ctx: EdgeContext[(DoubleMatrix, DoubleMatrix, Double, Double), Double, Double]) {
+ val (usr, itm) = (ctx.srcAttr, ctx.dstAttr)
val (p, q) = (usr._1, itm._1)
var pred = u + usr._3 + itm._3 + q.dot(usr._2)
pred = math.max(pred, conf.minVal)
pred = math.min(pred, conf.maxVal)
- val err = (et.attr - pred) * (et.attr - pred)
- Iterator((et.dstId, err))
+ val err = (ctx.attr - pred) * (ctx.attr - pred)
+ ctx.sendToDst(err)
}
g.cache()
- val t3 = g.mapReduceTriplets(mapTestF(conf, u), (g1: Double, g2: Double) => g1 + g2)
+ val t3 = g.aggregateMessages[Double](sendMsgTestF(conf, u), _ + _)
g = g.outerJoinVertices(t3) {
(vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[Double]) =>
if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala
index 7c396e6e66a28..daf162085e3e4 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala
@@ -61,26 +61,27 @@ object TriangleCount {
(vid, _, optSet) => optSet.getOrElse(null)
}
// Edge function computes intersection of smaller vertex with larger vertex
- def edgeFunc(et: EdgeTriplet[VertexSet, ED]): Iterator[(VertexId, Int)] = {
- assert(et.srcAttr != null)
- assert(et.dstAttr != null)
- val (smallSet, largeSet) = if (et.srcAttr.size < et.dstAttr.size) {
- (et.srcAttr, et.dstAttr)
+ def edgeFunc(ctx: EdgeContext[VertexSet, ED, Int]) {
+ assert(ctx.srcAttr != null)
+ assert(ctx.dstAttr != null)
+ val (smallSet, largeSet) = if (ctx.srcAttr.size < ctx.dstAttr.size) {
+ (ctx.srcAttr, ctx.dstAttr)
} else {
- (et.dstAttr, et.srcAttr)
+ (ctx.dstAttr, ctx.srcAttr)
}
val iter = smallSet.iterator
var counter: Int = 0
while (iter.hasNext) {
val vid = iter.next()
- if (vid != et.srcId && vid != et.dstId && largeSet.contains(vid)) {
+ if (vid != ctx.srcId && vid != ctx.dstId && largeSet.contains(vid)) {
counter += 1
}
}
- Iterator((et.srcId, counter), (et.dstId, counter))
+ ctx.sendToSrc(counter)
+ ctx.sendToDst(counter)
}
// compute the intersection along edges
- val counters: VertexRDD[Int] = setGraph.mapReduceTriplets(edgeFunc, _ + _)
+ val counters: VertexRDD[Int] = setGraph.aggregateMessages(edgeFunc, _ + _)
// Merge counters with the graph and divide by two since each triangle is counted twice
g.outerJoinVertices(counters) {
(vid, _, optCounter: Option[Int]) =>
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index 6506bac73d71c..a05d1ddb21295 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -118,7 +118,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
// Each vertex should be replicated to at most 2 * sqrt(p) partitions
val partitionSets = partitionedGraph.edges.partitionsRDD.mapPartitions { iter =>
val part = iter.next()._2
- Iterator((part.srcIds ++ part.dstIds).toSet)
+ Iterator((part.iterator.flatMap(e => Iterator(e.srcId, e.dstId))).toSet)
}.collect
if (!verts.forall(id => partitionSets.count(_.contains(id)) <= bound)) {
val numFailures = verts.count(id => partitionSets.count(_.contains(id)) > bound)
@@ -130,7 +130,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
// This should not be true for the default hash partitioning
val partitionSetsUnpartitioned = graph.edges.partitionsRDD.mapPartitions { iter =>
val part = iter.next()._2
- Iterator((part.srcIds ++ part.dstIds).toSet)
+ Iterator((part.iterator.flatMap(e => Iterator(e.srcId, e.dstId))).toSet)
}.collect
assert(verts.exists(id => partitionSetsUnpartitioned.count(_.contains(id)) > bound))
@@ -318,6 +318,21 @@ class GraphSuite extends FunSuite with LocalSparkContext {
}
}
+ test("aggregateMessages") {
+ withSpark { sc =>
+ val n = 5
+ val agg = starGraph(sc, n).aggregateMessages[String](
+ ctx => {
+ if (ctx.dstAttr != null) {
+ throw new Exception(
+ "expected ctx.dstAttr to be null due to TripletFields, but it was " + ctx.dstAttr)
+ }
+ ctx.sendToDst(ctx.srcAttr)
+ }, _ + _, TripletFields.Src)
+ assert(agg.collect().toSet === (1 to n).map(x => (x: VertexId, "v")).toSet)
+ }
+ }
+
test("outerJoinVertices") {
withSpark { sc =>
val n = 5
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala
index 47594a800a3b1..a3e28efc75a98 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala
@@ -17,9 +17,6 @@
package org.apache.spark.graphx
-import org.scalatest.Suite
-import org.scalatest.BeforeAndAfterEach
-
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
@@ -31,8 +28,7 @@ trait LocalSparkContext {
/** Runs `f` on a new SparkContext and ensures that it is stopped afterwards. */
def withSpark[T](f: SparkContext => T) = {
val conf = new SparkConf()
- .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
+ GraphXUtils.registerKryoClasses(conf)
val sc = new SparkContext("local", "test", conf)
try {
f(sc)
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala
deleted file mode 100644
index 864cb1fdf0022..0000000000000
--- a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.graphx
-
-import java.io.{EOFException, ByteArrayInputStream, ByteArrayOutputStream}
-
-import scala.util.Random
-import scala.reflect.ClassTag
-
-import org.scalatest.FunSuite
-
-import org.apache.spark._
-import org.apache.spark.graphx.impl._
-import org.apache.spark.serializer.SerializationStream
-
-
-class SerializerSuite extends FunSuite with LocalSparkContext {
-
- test("IntAggMsgSerializer") {
- val outMsg = (4: VertexId, 5)
- val bout = new ByteArrayOutputStream
- val outStrm = new IntAggMsgSerializer().newInstance().serializeStream(bout)
- outStrm.writeObject(outMsg)
- outStrm.writeObject(outMsg)
- bout.flush()
- val bin = new ByteArrayInputStream(bout.toByteArray)
- val inStrm = new IntAggMsgSerializer().newInstance().deserializeStream(bin)
- val inMsg1: (VertexId, Int) = inStrm.readObject()
- val inMsg2: (VertexId, Int) = inStrm.readObject()
- assert(outMsg === inMsg1)
- assert(outMsg === inMsg2)
-
- intercept[EOFException] {
- inStrm.readObject()
- }
- }
-
- test("LongAggMsgSerializer") {
- val outMsg = (4: VertexId, 1L << 32)
- val bout = new ByteArrayOutputStream
- val outStrm = new LongAggMsgSerializer().newInstance().serializeStream(bout)
- outStrm.writeObject(outMsg)
- outStrm.writeObject(outMsg)
- bout.flush()
- val bin = new ByteArrayInputStream(bout.toByteArray)
- val inStrm = new LongAggMsgSerializer().newInstance().deserializeStream(bin)
- val inMsg1: (VertexId, Long) = inStrm.readObject()
- val inMsg2: (VertexId, Long) = inStrm.readObject()
- assert(outMsg === inMsg1)
- assert(outMsg === inMsg2)
-
- intercept[EOFException] {
- inStrm.readObject()
- }
- }
-
- test("DoubleAggMsgSerializer") {
- val outMsg = (4: VertexId, 5.0)
- val bout = new ByteArrayOutputStream
- val outStrm = new DoubleAggMsgSerializer().newInstance().serializeStream(bout)
- outStrm.writeObject(outMsg)
- outStrm.writeObject(outMsg)
- bout.flush()
- val bin = new ByteArrayInputStream(bout.toByteArray)
- val inStrm = new DoubleAggMsgSerializer().newInstance().deserializeStream(bin)
- val inMsg1: (VertexId, Double) = inStrm.readObject()
- val inMsg2: (VertexId, Double) = inStrm.readObject()
- assert(outMsg === inMsg1)
- assert(outMsg === inMsg2)
-
- intercept[EOFException] {
- inStrm.readObject()
- }
- }
-
- test("variable long encoding") {
- def testVarLongEncoding(v: Long, optimizePositive: Boolean) {
- val bout = new ByteArrayOutputStream
- val stream = new ShuffleSerializationStream(bout) {
- def writeObject[T: ClassTag](t: T): SerializationStream = {
- writeVarLong(t.asInstanceOf[Long], optimizePositive = optimizePositive)
- this
- }
- }
- stream.writeObject(v)
-
- val bin = new ByteArrayInputStream(bout.toByteArray)
- val dstream = new ShuffleDeserializationStream(bin) {
- def readObject[T: ClassTag](): T = {
- readVarLong(optimizePositive).asInstanceOf[T]
- }
- }
- val read = dstream.readObject[Long]()
- assert(read === v)
- }
-
- // Test all variable encoding code path (each branch uses 7 bits, i.e. 1L << 7 difference)
- val d = Random.nextLong() % 128
- Seq[Long](0, 1L << 0 + d, 1L << 7 + d, 1L << 14 + d, 1L << 21 + d, 1L << 28 + d, 1L << 35 + d,
- 1L << 42 + d, 1L << 49 + d, 1L << 56 + d, 1L << 63 + d).foreach { number =>
- testVarLongEncoding(number, optimizePositive = false)
- testVarLongEncoding(number, optimizePositive = true)
- testVarLongEncoding(-number, optimizePositive = false)
- testVarLongEncoding(-number, optimizePositive = true)
- }
- }
-}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
index 9d00f76327e4c..515f3a9cd02eb 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
@@ -82,29 +82,6 @@ class EdgePartitionSuite extends FunSuite {
assert(edgePartition.groupEdges(_ + _).iterator.map(_.copy()).toList === groupedEdges)
}
- test("upgradeIterator") {
- val edges = List((0, 1, 0), (1, 0, 0))
- val verts = List((0L, 1), (1L, 2))
- val part = makeEdgePartition(edges).updateVertices(verts.iterator)
- assert(part.upgradeIterator(part.iterator).map(_.toTuple).toList ===
- part.tripletIterator().toList.map(_.toTuple))
- }
-
- test("indexIterator") {
- val edgesFrom0 = List(Edge(0, 1, 0))
- val edgesFrom1 = List(Edge(1, 0, 0), Edge(1, 2, 0))
- val sortedEdges = edgesFrom0 ++ edgesFrom1
- val builder = new EdgePartitionBuilder[Int, Nothing]
- for (e <- Random.shuffle(sortedEdges)) {
- builder.add(e.srcId, e.dstId, e.attr)
- }
-
- val edgePartition = builder.toEdgePartition
- assert(edgePartition.iterator.map(_.copy()).toList === sortedEdges)
- assert(edgePartition.indexIterator(_ == 0).map(_.copy()).toList === edgesFrom0)
- assert(edgePartition.indexIterator(_ == 1).map(_.copy()).toList === edgesFrom1)
- }
-
test("innerJoin") {
val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0))
val bList = List((0, 1, 0), (1, 0, 0), (1, 1, 0), (3, 4, 0), (5, 5, 0))
@@ -125,21 +102,27 @@ class EdgePartitionSuite extends FunSuite {
assert(ep.numActives == Some(2))
}
+ test("tripletIterator") {
+ val builder = new EdgePartitionBuilder[Int, Int]
+ builder.add(1, 2, 0)
+ builder.add(1, 3, 0)
+ builder.add(1, 4, 0)
+ val ep = builder.toEdgePartition
+ val result = ep.tripletIterator().toList.map(et => (et.srcId, et.dstId))
+ assert(result === Seq((1, 2), (1, 3), (1, 4)))
+ }
+
test("serialization") {
- val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0))
+ val aList = List((0, 1, 1), (1, 0, 2), (1, 2, 3), (5, 4, 4), (5, 5, 5))
val a: EdgePartition[Int, Int] = makeEdgePartition(aList)
val javaSer = new JavaSerializer(new SparkConf())
- val kryoSer = new KryoSerializer(new SparkConf()
- .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator"))
+ val conf = new SparkConf()
+ GraphXUtils.registerKryoClasses(conf)
+ val kryoSer = new KryoSerializer(conf)
for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) {
val aSer: EdgePartition[Int, Int] = s.deserialize(s.serialize(a))
- assert(aSer.srcIds.toList === a.srcIds.toList)
- assert(aSer.dstIds.toList === a.dstIds.toList)
- assert(aSer.data.toList === a.data.toList)
- assert(aSer.index != null)
- assert(aSer.vertices.iterator.toSet === a.vertices.iterator.toSet)
+ assert(aSer.tripletIterator().toList === a.tripletIterator().toList)
}
}
}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala
deleted file mode 100644
index 49b2704390fea..0000000000000
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.graphx.impl
-
-import scala.reflect.ClassTag
-import scala.util.Random
-
-import org.scalatest.FunSuite
-
-import org.apache.spark.graphx._
-
-class EdgeTripletIteratorSuite extends FunSuite {
- test("iterator.toList") {
- val builder = new EdgePartitionBuilder[Int, Int]
- builder.add(1, 2, 0)
- builder.add(1, 3, 0)
- builder.add(1, 4, 0)
- val iter = new EdgeTripletIterator[Int, Int](builder.toEdgePartition, true, true)
- val result = iter.toList.map(et => (et.srcId, et.dstId))
- assert(result === Seq((1, 2), (1, 3), (1, 4)))
- }
-}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
index f9e771a900013..fe8304c1cdc32 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
@@ -125,9 +125,9 @@ class VertexPartitionSuite extends FunSuite {
val verts = Set((0L, 1), (1L, 1), (2L, 1))
val vp = VertexPartition(verts.iterator)
val javaSer = new JavaSerializer(new SparkConf())
- val kryoSer = new KryoSerializer(new SparkConf()
- .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator"))
+ val conf = new SparkConf()
+ GraphXUtils.registerKryoClasses(conf)
+ val kryoSer = new KryoSerializer(conf)
for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) {
val vpSer: VertexPartition[Int] = s.deserialize(s.serialize(vp))
diff --git a/make-distribution.sh b/make-distribution.sh
index 0bc839e1dbe4d..7c0fb8992a155 100755
--- a/make-distribution.sh
+++ b/make-distribution.sh
@@ -59,7 +59,7 @@ while (( "$#" )); do
exit_with_usage
;;
--with-hive)
- echo "Error: '--with-hive' is no longer supported, use Maven option -Phive"
+ echo "Error: '--with-hive' is no longer supported, use Maven options -Phive and -Phive-thriftserver"
exit_with_usage
;;
--skip-java-test)
@@ -119,7 +119,7 @@ VERSION=$(mvn help:evaluate -Dexpression=project.version 2>/dev/null | grep -v "
SPARK_HADOOP_VERSION=$(mvn help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\
| grep -v "INFO"\
| tail -n 1)
-SPARK_HIVE=$(mvn help:evaluate -Dexpression=project.activeProfiles $@ 2>/dev/null\
+SPARK_HIVE=$(mvn help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\
| grep -v "INFO"\
| fgrep --count "hive";\
# Reset exit status to 0, otherwise the script stops here if the last grep finds nothing\
@@ -181,6 +181,9 @@ echo "Spark $VERSION$GITREVSTRING built for Hadoop $SPARK_HADOOP_VERSION" > "$DI
# Copy jars
cp "$FWDIR"/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/"
cp "$FWDIR"/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/"
+# This will fail if the -Pyarn profile is not provided
+# In this case, silence the error and ignore the return code of this command
+cp "$FWDIR"/network/yarn/target/scala*/spark-*-yarn-shuffle.jar "$DISTDIR/lib/" &> /dev/null || :
# Copy example sources (needed for python and SQL)
mkdir -p "$DISTDIR/examples/src/main"
diff --git a/mllib/pom.xml b/mllib/pom.xml
index cfeabe4025de6..878aff66b3728 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -21,7 +21,7 @@
org.apache.sparkspark-parent
- 1.2.0-SNAPSHOT
+ 1.3.0-SNAPSHOT../pom.xml
@@ -45,6 +45,11 @@
spark-streaming_${scala.binary.version}${project.version}
+
+ org.apache.spark
+ spark-sql_${scala.binary.version}
+ ${project.version}
+ org.eclipse.jettyjetty-server
@@ -57,7 +62,7 @@
org.scalanlpbreeze_${scala.binary.version}
- 0.9
+ 0.10
@@ -71,6 +76,10 @@
+
+ org.apache.commons
+ commons-math3
+ org.scalatestscalatest_${scala.binary.version}
@@ -91,6 +100,11 @@
junit-interfacetest
+
+ org.mockito
+ mockito-all
+ test
+ org.apache.sparkspark-streaming_${scala.binary.version}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
new file mode 100644
index 0000000000000..fdbee743e8177
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import scala.annotation.varargs
+import scala.collection.JavaConverters._
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param.{ParamMap, ParamPair, Params}
+import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.api.java.JavaSchemaRDD
+
+/**
+ * :: AlphaComponent ::
+ * Abstract class for estimators that fit models to data.
+ */
+@AlphaComponent
+abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
+
+ /**
+ * Fits a single model to the input data with optional parameters.
+ *
+ * @param dataset input dataset
+ * @param paramPairs optional list of param pairs (overwrite embedded params)
+ * @return fitted model
+ */
+ @varargs
+ def fit(dataset: SchemaRDD, paramPairs: ParamPair[_]*): M = {
+ val map = new ParamMap().put(paramPairs: _*)
+ fit(dataset, map)
+ }
+
+ /**
+ * Fits a single model to the input data with provided parameter map.
+ *
+ * @param dataset input dataset
+ * @param paramMap parameter map
+ * @return fitted model
+ */
+ def fit(dataset: SchemaRDD, paramMap: ParamMap): M
+
+ /**
+ * Fits multiple models to the input data with multiple sets of parameters.
+ * The default implementation uses a for loop on each parameter map.
+ * Subclasses could overwrite this to optimize multi-model training.
+ *
+ * @param dataset input dataset
+ * @param paramMaps an array of parameter maps
+ * @return fitted models, matching the input parameter maps
+ */
+ def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = {
+ paramMaps.map(fit(dataset, _))
+ }
+
+ // Java-friendly versions of fit.
+
+ /**
+ * Fits a single model to the input data with optional parameters.
+ *
+ * @param dataset input dataset
+ * @param paramPairs optional list of param pairs (overwrite embedded params)
+ * @return fitted model
+ */
+ @varargs
+ def fit(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): M = {
+ fit(dataset.schemaRDD, paramPairs: _*)
+ }
+
+ /**
+ * Fits a single model to the input data with provided parameter map.
+ *
+ * @param dataset input dataset
+ * @param paramMap parameter map
+ * @return fitted model
+ */
+ def fit(dataset: JavaSchemaRDD, paramMap: ParamMap): M = {
+ fit(dataset.schemaRDD, paramMap)
+ }
+
+ /**
+ * Fits multiple models to the input data with multiple sets of parameters.
+ *
+ * @param dataset input dataset
+ * @param paramMaps an array of parameter maps
+ * @return fitted models, matching the input parameter maps
+ */
+ def fit(dataset: JavaSchemaRDD, paramMaps: Array[ParamMap]): java.util.List[M] = {
+ fit(dataset.schemaRDD, paramMaps).asJava
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
new file mode 100644
index 0000000000000..db563dd550e56
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.sql.SchemaRDD
+
+/**
+ * :: AlphaComponent ::
+ * Abstract class for evaluators that compute metrics from predictions.
+ */
+@AlphaComponent
+abstract class Evaluator extends Identifiable {
+
+ /**
+ * Evaluates the output.
+ *
+ * @param dataset a dataset that contains labels/observations and predictions.
+ * @param paramMap parameter map that specifies the input columns and output metrics
+ * @return metric
+ */
+ def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala
new file mode 100644
index 0000000000000..cd84b05bfb496
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import java.util.UUID
+
+/**
+ * Object with a unique id.
+ */
+private[ml] trait Identifiable extends Serializable {
+
+ /**
+ * A unique id for the object. The default implementation concatenates the class name, "-", and 8
+ * random hex chars.
+ */
+ private[ml] val uid: String =
+ this.getClass.getSimpleName + "-" + UUID.randomUUID().toString.take(8)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
new file mode 100644
index 0000000000000..cae5082b51196
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param.ParamMap
+
+/**
+ * :: AlphaComponent ::
+ * A fitted model, i.e., a [[Transformer]] produced by an [[Estimator]].
+ *
+ * @tparam M model type
+ */
+@AlphaComponent
+abstract class Model[M <: Model[M]] extends Transformer {
+ /**
+ * The parent estimator that produced this model.
+ */
+ val parent: Estimator[M]
+
+ /**
+ * Fitting parameters, such that parent.fit(..., fittingParamMap) could reproduce the model.
+ */
+ val fittingParamMap: ParamMap
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
new file mode 100644
index 0000000000000..e545df1e37b9c
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param.{Params, Param, ParamMap}
+import org.apache.spark.sql.{SchemaRDD, StructType}
+
+/**
+ * :: AlphaComponent ::
+ * A stage in a pipeline, either an [[Estimator]] or a [[Transformer]].
+ */
+@AlphaComponent
+abstract class PipelineStage extends Serializable with Logging {
+
+ /**
+ * Derives the output schema from the input schema and parameters.
+ */
+ private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType
+
+ /**
+ * Derives the output schema from the input schema and parameters, optionally with logging.
+ */
+ protected def transformSchema(
+ schema: StructType,
+ paramMap: ParamMap,
+ logging: Boolean): StructType = {
+ if (logging) {
+ logDebug(s"Input schema: ${schema.json}")
+ }
+ val outputSchema = transformSchema(schema, paramMap)
+ if (logging) {
+ logDebug(s"Expected output schema: ${outputSchema.json}")
+ }
+ outputSchema
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each
+ * of which is either an [[Estimator]] or a [[Transformer]]. When [[Pipeline.fit]] is called, the
+ * stages are executed in order. If a stage is an [[Estimator]], its [[Estimator.fit]] method will
+ * be called on the input dataset to fit a model. Then the model, which is a transformer, will be
+ * used to transform the dataset as the input to the next stage. If a stage is a [[Transformer]],
+ * its [[Transformer.transform]] method will be called to produce the dataset for the next stage.
+ * The fitted model from a [[Pipeline]] is an [[PipelineModel]], which consists of fitted models and
+ * transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as
+ * an identity transformer.
+ */
+@AlphaComponent
+class Pipeline extends Estimator[PipelineModel] {
+
+ /** param for pipeline stages */
+ val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline")
+ def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
+ def getStages: Array[PipelineStage] = get(stages)
+
+ /**
+ * Fits the pipeline to the input dataset with additional parameters. If a stage is an
+ * [[Estimator]], its [[Estimator.fit]] method will be called on the input dataset to fit a model.
+ * Then the model, which is a transformer, will be used to transform the dataset as the input to
+ * the next stage. If a stage is a [[Transformer]], its [[Transformer.transform]] method will be
+ * called to produce the dataset for the next stage. The fitted model from a [[Pipeline]] is an
+ * [[PipelineModel]], which consists of fitted models and transformers, corresponding to the
+ * pipeline stages. If there are no stages, the output model acts as an identity transformer.
+ *
+ * @param dataset input dataset
+ * @param paramMap parameter map
+ * @return fitted pipeline
+ */
+ override def fit(dataset: SchemaRDD, paramMap: ParamMap): PipelineModel = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ val map = this.paramMap ++ paramMap
+ val theStages = map(stages)
+ // Search for the last estimator.
+ var indexOfLastEstimator = -1
+ theStages.view.zipWithIndex.foreach { case (stage, index) =>
+ stage match {
+ case _: Estimator[_] =>
+ indexOfLastEstimator = index
+ case _ =>
+ }
+ }
+ var curDataset = dataset
+ val transformers = ListBuffer.empty[Transformer]
+ theStages.view.zipWithIndex.foreach { case (stage, index) =>
+ if (index <= indexOfLastEstimator) {
+ val transformer = stage match {
+ case estimator: Estimator[_] =>
+ estimator.fit(curDataset, paramMap)
+ case t: Transformer =>
+ t
+ case _ =>
+ throw new IllegalArgumentException(
+ s"Do not support stage $stage of type ${stage.getClass}")
+ }
+ curDataset = transformer.transform(curDataset, paramMap)
+ transformers += transformer
+ } else {
+ transformers += stage.asInstanceOf[Transformer]
+ }
+ }
+
+ new PipelineModel(this, map, transformers.toArray)
+ }
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ val map = this.paramMap ++ paramMap
+ val theStages = map(stages)
+ require(theStages.toSet.size == theStages.size,
+ "Cannot have duplicate components in a pipeline.")
+ theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur, paramMap))
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * Represents a compiled pipeline.
+ */
+@AlphaComponent
+class PipelineModel private[ml] (
+ override val parent: Pipeline,
+ override val fittingParamMap: ParamMap,
+ private[ml] val stages: Array[Transformer])
+ extends Model[PipelineModel] with Logging {
+
+ /**
+ * Gets the model produced by the input estimator. Throws an NoSuchElementException is the input
+ * estimator does not exist in the pipeline.
+ */
+ def getModel[M <: Model[M]](stage: Estimator[M]): M = {
+ val matched = stages.filter {
+ case m: Model[_] => m.parent.eq(stage)
+ case _ => false
+ }
+ if (matched.isEmpty) {
+ throw new NoSuchElementException(s"Cannot find stage $stage from the pipeline.")
+ } else if (matched.size > 1) {
+ throw new IllegalStateException(s"Cannot have duplicate estimators in the sample pipeline.")
+ } else {
+ matched.head.asInstanceOf[M]
+ }
+ }
+
+ override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, paramMap))
+ }
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, paramMap))
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
new file mode 100644
index 0000000000000..490e6609ad311
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import scala.annotation.varargs
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param._
+import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.api.java.JavaSchemaRDD
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.analysis.Star
+import org.apache.spark.sql.catalyst.dsl._
+import org.apache.spark.sql.catalyst.types._
+
+/**
+ * :: AlphaComponent ::
+ * Abstract class for transformers that transform one dataset into another.
+ */
+@AlphaComponent
+abstract class Transformer extends PipelineStage with Params {
+
+ /**
+ * Transforms the dataset with optional parameters
+ * @param dataset input dataset
+ * @param paramPairs optional list of param pairs, overwrite embedded params
+ * @return transformed dataset
+ */
+ @varargs
+ def transform(dataset: SchemaRDD, paramPairs: ParamPair[_]*): SchemaRDD = {
+ val map = new ParamMap()
+ paramPairs.foreach(map.put(_))
+ transform(dataset, map)
+ }
+
+ /**
+ * Transforms the dataset with provided parameter map as additional parameters.
+ * @param dataset input dataset
+ * @param paramMap additional parameters, overwrite embedded params
+ * @return transformed dataset
+ */
+ def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD
+
+ // Java-friendly versions of transform.
+
+ /**
+ * Transforms the dataset with optional parameters.
+ * @param dataset input datset
+ * @param paramPairs optional list of param pairs, overwrite embedded params
+ * @return transformed dataset
+ */
+ @varargs
+ def transform(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): JavaSchemaRDD = {
+ transform(dataset.schemaRDD, paramPairs: _*).toJavaSchemaRDD
+ }
+
+ /**
+ * Transforms the dataset with provided parameter map as additional parameters.
+ * @param dataset input dataset
+ * @param paramMap additional parameters, overwrite embedded params
+ * @return transformed dataset
+ */
+ def transform(dataset: JavaSchemaRDD, paramMap: ParamMap): JavaSchemaRDD = {
+ transform(dataset.schemaRDD, paramMap).toJavaSchemaRDD
+ }
+}
+
+/**
+ * Abstract class for transformers that take one input column, apply transformation, and output the
+ * result as a new column.
+ */
+private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]]
+ extends Transformer with HasInputCol with HasOutputCol with Logging {
+
+ def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T]
+ def setOutputCol(value: String): T = set(outputCol, value).asInstanceOf[T]
+
+ /**
+ * Creates the transform function using the given param map. The input param map already takes
+ * account of the embedded param map. So the param values should be determined solely by the input
+ * param map.
+ */
+ protected def createTransformFunc(paramMap: ParamMap): IN => OUT
+
+ /**
+ * Validates the input type. Throw an exception if it is invalid.
+ */
+ protected def validateInputType(inputType: DataType): Unit = {}
+
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ val map = this.paramMap ++ paramMap
+ val inputType = schema(map(inputCol)).dataType
+ validateInputType(inputType)
+ if (schema.fieldNames.contains(map(outputCol))) {
+ throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.")
+ }
+ val output = ScalaReflection.schemaFor[OUT]
+ val outputFields = schema.fields :+
+ StructField(map(outputCol), output.dataType, output.nullable)
+ StructType(outputFields)
+ }
+
+ override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ import dataset.sqlContext._
+ val map = this.paramMap ++ paramMap
+ val udf = this.createTransformFunc(map)
+ dataset.select(Star(None), udf.call(map(inputCol).attr) as map(outputCol))
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
new file mode 100644
index 0000000000000..85b8899636ca5
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml._
+import org.apache.spark.ml.param._
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
+import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.analysis.Star
+import org.apache.spark.sql.catalyst.dsl._
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * :: AlphaComponent ::
+ * Params for logistic regression.
+ */
+@AlphaComponent
+private[classification] trait LogisticRegressionParams extends Params
+ with HasRegParam with HasMaxIter with HasLabelCol with HasThreshold with HasFeaturesCol
+ with HasScoreCol with HasPredictionCol {
+
+ /**
+ * Validates and transforms the input schema with the provided param map.
+ * @param schema input schema
+ * @param paramMap additional parameters
+ * @param fitting whether this is in fitting
+ * @return output schema
+ */
+ protected def validateAndTransformSchema(
+ schema: StructType,
+ paramMap: ParamMap,
+ fitting: Boolean): StructType = {
+ val map = this.paramMap ++ paramMap
+ val featuresType = schema(map(featuresCol)).dataType
+ // TODO: Support casting Array[Double] and Array[Float] to Vector.
+ require(featuresType.isInstanceOf[VectorUDT],
+ s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.")
+ if (fitting) {
+ val labelType = schema(map(labelCol)).dataType
+ require(labelType == DoubleType,
+ s"Cannot convert label column ${map(labelCol)} of type $labelType to a double column.")
+ }
+ val fieldNames = schema.fieldNames
+ require(!fieldNames.contains(map(scoreCol)), s"Score column ${map(scoreCol)} already exists.")
+ require(!fieldNames.contains(map(predictionCol)),
+ s"Prediction column ${map(predictionCol)} already exists.")
+ val outputFields = schema.fields ++ Seq(
+ StructField(map(scoreCol), DoubleType, false),
+ StructField(map(predictionCol), DoubleType, false))
+ StructType(outputFields)
+ }
+}
+
+/**
+ * Logistic regression.
+ */
+class LogisticRegression extends Estimator[LogisticRegressionModel] with LogisticRegressionParams {
+
+ setRegParam(0.1)
+ setMaxIter(100)
+ setThreshold(0.5)
+
+ def setRegParam(value: Double): this.type = set(regParam, value)
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+ def setLabelCol(value: String): this.type = set(labelCol, value)
+ def setThreshold(value: Double): this.type = set(threshold, value)
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+ def setScoreCol(value: String): this.type = set(scoreCol, value)
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+ override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ import dataset.sqlContext._
+ val map = this.paramMap ++ paramMap
+ val instances = dataset.select(map(labelCol).attr, map(featuresCol).attr)
+ .map { case Row(label: Double, features: Vector) =>
+ LabeledPoint(label, features)
+ }.persist(StorageLevel.MEMORY_AND_DISK)
+ val lr = new LogisticRegressionWithLBFGS
+ lr.optimizer
+ .setRegParam(map(regParam))
+ .setNumIterations(map(maxIter))
+ val lrm = new LogisticRegressionModel(this, map, lr.run(instances).weights)
+ instances.unpersist()
+ // copy model params
+ Params.inheritValues(map, this, lrm)
+ lrm
+ }
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ validateAndTransformSchema(schema, paramMap, fitting = true)
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * Model produced by [[LogisticRegression]].
+ */
+@AlphaComponent
+class LogisticRegressionModel private[ml] (
+ override val parent: LogisticRegression,
+ override val fittingParamMap: ParamMap,
+ weights: Vector)
+ extends Model[LogisticRegressionModel] with LogisticRegressionParams {
+
+ def setThreshold(value: Double): this.type = set(threshold, value)
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+ def setScoreCol(value: String): this.type = set(scoreCol, value)
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ validateAndTransformSchema(schema, paramMap, fitting = false)
+ }
+
+ override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ import dataset.sqlContext._
+ val map = this.paramMap ++ paramMap
+ val score: Vector => Double = (v) => {
+ val margin = BLAS.dot(v, weights)
+ 1.0 / (1.0 + math.exp(-margin))
+ }
+ val t = map(threshold)
+ val predict: Double => Double = (score) => {
+ if (score > t) 1.0 else 0.0
+ }
+ dataset.select(Star(None), score.call(map(featuresCol).attr) as map(scoreCol))
+ .select(Star(None), predict.call(map(scoreCol).attr) as map(predictionCol))
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
new file mode 100644
index 0000000000000..0b0504e036ec9
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml._
+import org.apache.spark.ml.param._
+import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
+import org.apache.spark.sql.{DoubleType, Row, SchemaRDD}
+
+/**
+ * :: AlphaComponent ::
+ * Evaluator for binary classification, which expects two input columns: score and label.
+ */
+@AlphaComponent
+class BinaryClassificationEvaluator extends Evaluator with Params
+ with HasScoreCol with HasLabelCol {
+
+ /** param for metric name in evaluation */
+ val metricName: Param[String] = new Param(this, "metricName",
+ "metric name in evaluation (areaUnderROC|areaUnderPR)", Some("areaUnderROC"))
+ def getMetricName: String = get(metricName)
+ def setMetricName(value: String): this.type = set(metricName, value)
+
+ def setScoreCol(value: String): this.type = set(scoreCol, value)
+ def setLabelCol(value: String): this.type = set(labelCol, value)
+
+ override def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double = {
+ val map = this.paramMap ++ paramMap
+
+ val schema = dataset.schema
+ val scoreType = schema(map(scoreCol)).dataType
+ require(scoreType == DoubleType,
+ s"Score column ${map(scoreCol)} must be double type but found $scoreType")
+ val labelType = schema(map(labelCol)).dataType
+ require(labelType == DoubleType,
+ s"Label column ${map(labelCol)} must be double type but found $labelType")
+
+ import dataset.sqlContext._
+ val scoreAndLabels = dataset.select(map(scoreCol).attr, map(labelCol).attr)
+ .map { case Row(score: Double, label: Double) =>
+ (score, label)
+ }
+ val metrics = new BinaryClassificationMetrics(scoreAndLabels)
+ val metric = map(metricName) match {
+ case "areaUnderROC" =>
+ metrics.areaUnderROC()
+ case "areaUnderPR" =>
+ metrics.areaUnderPR()
+ case other =>
+ throw new IllegalArgumentException(s"Does not support metric $other.")
+ }
+ metrics.unpersist()
+ metric
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
new file mode 100644
index 0000000000000..b98b1755a3584
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.param.{IntParam, ParamMap}
+import org.apache.spark.mllib.feature
+import org.apache.spark.mllib.linalg.Vector
+
+/**
+ * :: AlphaComponent ::
+ * Maps a sequence of terms to their term frequencies using the hashing trick.
+ */
+@AlphaComponent
+class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
+
+ /** number of features */
+ val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18))
+ def setNumFeatures(value: Int) = set(numFeatures, value)
+ def getNumFeatures: Int = get(numFeatures)
+
+ override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = {
+ val hashingTF = new feature.HashingTF(paramMap(numFeatures))
+ hashingTF.transform
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
new file mode 100644
index 0000000000000..896a6b83b67bf
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml._
+import org.apache.spark.ml.param._
+import org.apache.spark.mllib.feature
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.analysis.Star
+import org.apache.spark.sql.catalyst.dsl._
+
+/**
+ * Params for [[StandardScaler]] and [[StandardScalerModel]].
+ */
+private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol
+
+/**
+ * :: AlphaComponent ::
+ * Standardizes features by removing the mean and scaling to unit variance using column summary
+ * statistics on the samples in the training set.
+ */
+@AlphaComponent
+class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams {
+
+ def setInputCol(value: String): this.type = set(inputCol, value)
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ override def fit(dataset: SchemaRDD, paramMap: ParamMap): StandardScalerModel = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ import dataset.sqlContext._
+ val map = this.paramMap ++ paramMap
+ val input = dataset.select(map(inputCol).attr)
+ .map { case Row(v: Vector) =>
+ v
+ }
+ val scaler = new feature.StandardScaler().fit(input)
+ val model = new StandardScalerModel(this, map, scaler)
+ Params.inheritValues(map, this, model)
+ model
+ }
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ val map = this.paramMap ++ paramMap
+ val inputType = schema(map(inputCol)).dataType
+ require(inputType.isInstanceOf[VectorUDT],
+ s"Input column ${map(inputCol)} must be a vector column")
+ require(!schema.fieldNames.contains(map(outputCol)),
+ s"Output column ${map(outputCol)} already exists.")
+ val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false)
+ StructType(outputFields)
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * Model fitted by [[StandardScaler]].
+ */
+@AlphaComponent
+class StandardScalerModel private[ml] (
+ override val parent: StandardScaler,
+ override val fittingParamMap: ParamMap,
+ scaler: feature.StandardScalerModel)
+ extends Model[StandardScalerModel] with StandardScalerParams {
+
+ def setInputCol(value: String): this.type = set(inputCol, value)
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ import dataset.sqlContext._
+ val map = this.paramMap ++ paramMap
+ val scale: (Vector) => Vector = (v) => {
+ scaler.transform(v)
+ }
+ dataset.select(Star(None), scale.call(map(inputCol).attr) as map(outputCol))
+ }
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ val map = this.paramMap ++ paramMap
+ val inputType = schema(map(inputCol)).dataType
+ require(inputType.isInstanceOf[VectorUDT],
+ s"Input column ${map(inputCol)} must be a vector column")
+ require(!schema.fieldNames.contains(map(outputCol)),
+ s"Output column ${map(outputCol)} already exists.")
+ val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false)
+ StructType(outputFields)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
new file mode 100644
index 0000000000000..0a6599b64c011
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.sql.{DataType, StringType}
+
+/**
+ * :: AlphaComponent ::
+ * A tokenizer that converts the input string to lowercase and then splits it by white spaces.
+ */
+@AlphaComponent
+class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
+
+ protected override def createTransformFunc(paramMap: ParamMap): String => Seq[String] = {
+ _.toLowerCase.split("\\s")
+ }
+
+ protected override def validateInputType(inputType: DataType): Unit = {
+ require(inputType == StringType, s"Input type must be string type but got $inputType.")
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/package-info.java
new file mode 100644
index 0000000000000..00d9c802e930d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/package-info.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly
+ * assemble and configure practical machine learning pipelines.
+ */
+@AlphaComponent
+package org.apache.spark.ml;
+
+import org.apache.spark.annotation.AlphaComponent;
diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala
new file mode 100644
index 0000000000000..51cd48c90432a
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly
+ * assemble and configure practical machine learning pipelines.
+ */
+package object ml
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
new file mode 100644
index 0000000000000..8fd46aef4b99d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -0,0 +1,321 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param
+
+import java.lang.reflect.Modifier
+
+import org.apache.spark.annotation.AlphaComponent
+
+import scala.annotation.varargs
+import scala.collection.mutable
+
+import org.apache.spark.ml.Identifiable
+
+/**
+ * :: AlphaComponent ::
+ * A param with self-contained documentation and optionally default value. Primitive-typed param
+ * should use the specialized versions, which are more friendly to Java users.
+ *
+ * @param parent parent object
+ * @param name param name
+ * @param doc documentation
+ * @tparam T param value type
+ */
+@AlphaComponent
+class Param[T] (
+ val parent: Params,
+ val name: String,
+ val doc: String,
+ val defaultValue: Option[T] = None)
+ extends Serializable {
+
+ /**
+ * Creates a param pair with the given value (for Java).
+ */
+ def w(value: T): ParamPair[T] = this -> value
+
+ /**
+ * Creates a param pair with the given value (for Scala).
+ */
+ def ->(value: T): ParamPair[T] = ParamPair(this, value)
+
+ override def toString: String = {
+ if (defaultValue.isDefined) {
+ s"$name: $doc (default: ${defaultValue.get})"
+ } else {
+ s"$name: $doc"
+ }
+ }
+}
+
+// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
+
+/** Specialized version of [[Param[Double]]] for Java. */
+class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double] = None)
+ extends Param[Double](parent, name, doc, defaultValue) {
+
+ override def w(value: Double): ParamPair[Double] = super.w(value)
+}
+
+/** Specialized version of [[Param[Int]]] for Java. */
+class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int] = None)
+ extends Param[Int](parent, name, doc, defaultValue) {
+
+ override def w(value: Int): ParamPair[Int] = super.w(value)
+}
+
+/** Specialized version of [[Param[Float]]] for Java. */
+class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float] = None)
+ extends Param[Float](parent, name, doc, defaultValue) {
+
+ override def w(value: Float): ParamPair[Float] = super.w(value)
+}
+
+/** Specialized version of [[Param[Long]]] for Java. */
+class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long] = None)
+ extends Param[Long](parent, name, doc, defaultValue) {
+
+ override def w(value: Long): ParamPair[Long] = super.w(value)
+}
+
+/** Specialized version of [[Param[Boolean]]] for Java. */
+class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean] = None)
+ extends Param[Boolean](parent, name, doc, defaultValue) {
+
+ override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
+}
+
+/**
+ * A param amd its value.
+ */
+case class ParamPair[T](param: Param[T], value: T)
+
+/**
+ * :: AlphaComponent ::
+ * Trait for components that take parameters. This also provides an internal param map to store
+ * parameter values attached to the instance.
+ */
+@AlphaComponent
+trait Params extends Identifiable with Serializable {
+
+ /** Returns all params. */
+ def params: Array[Param[_]] = {
+ val methods = this.getClass.getMethods
+ methods.filter { m =>
+ Modifier.isPublic(m.getModifiers) &&
+ classOf[Param[_]].isAssignableFrom(m.getReturnType) &&
+ m.getParameterTypes.isEmpty
+ }.sortBy(_.getName)
+ .map(m => m.invoke(this).asInstanceOf[Param[_]])
+ }
+
+ /**
+ * Validates parameter values stored internally plus the input parameter map.
+ * Raises an exception if any parameter is invalid.
+ */
+ def validate(paramMap: ParamMap): Unit = {}
+
+ /**
+ * Validates parameter values stored internally.
+ * Raise an exception if any parameter value is invalid.
+ */
+ def validate(): Unit = validate(ParamMap.empty)
+
+ /**
+ * Returns the documentation of all params.
+ */
+ def explainParams(): String = params.mkString("\n")
+
+ /** Checks whether a param is explicitly set. */
+ def isSet(param: Param[_]): Boolean = {
+ require(param.parent.eq(this))
+ paramMap.contains(param)
+ }
+
+ /** Gets a param by its name. */
+ private[ml] def getParam(paramName: String): Param[Any] = {
+ val m = this.getClass.getMethod(paramName)
+ assert(Modifier.isPublic(m.getModifiers) &&
+ classOf[Param[_]].isAssignableFrom(m.getReturnType) &&
+ m.getParameterTypes.isEmpty)
+ m.invoke(this).asInstanceOf[Param[Any]]
+ }
+
+ /**
+ * Sets a parameter in the embedded param map.
+ */
+ private[ml] def set[T](param: Param[T], value: T): this.type = {
+ require(param.parent.eq(this))
+ paramMap.put(param.asInstanceOf[Param[Any]], value)
+ this
+ }
+
+ /**
+ * Gets the value of a parameter in the embedded param map.
+ */
+ private[ml] def get[T](param: Param[T]): T = {
+ require(param.parent.eq(this))
+ paramMap(param)
+ }
+
+ /**
+ * Internal param map.
+ */
+ protected val paramMap: ParamMap = ParamMap.empty
+}
+
+private[ml] object Params {
+
+ /**
+ * Copies parameter values from the parent estimator to the child model it produced.
+ * @param paramMap the param map that holds parameters of the parent
+ * @param parent the parent estimator
+ * @param child the child model
+ */
+ def inheritValues[E <: Params, M <: E](
+ paramMap: ParamMap,
+ parent: E,
+ child: M): Unit = {
+ parent.params.foreach { param =>
+ if (paramMap.contains(param)) {
+ child.set(child.getParam(param.name), paramMap(param))
+ }
+ }
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * A param to value map.
+ */
+@AlphaComponent
+class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable {
+
+ /**
+ * Creates an empty param map.
+ */
+ def this() = this(mutable.Map.empty[Param[Any], Any])
+
+ /**
+ * Puts a (param, value) pair (overwrites if the input param exists).
+ */
+ def put[T](param: Param[T], value: T): this.type = {
+ map(param.asInstanceOf[Param[Any]]) = value
+ this
+ }
+
+ /**
+ * Puts a list of param pairs (overwrites if the input params exists).
+ */
+ def put(paramPairs: ParamPair[_]*): this.type = {
+ paramPairs.foreach { p =>
+ put(p.param.asInstanceOf[Param[Any]], p.value)
+ }
+ this
+ }
+
+ /**
+ * Optionally returns the value associated with a param or its default.
+ */
+ def get[T](param: Param[T]): Option[T] = {
+ map.get(param.asInstanceOf[Param[Any]])
+ .orElse(param.defaultValue)
+ .asInstanceOf[Option[T]]
+ }
+
+ /**
+ * Gets the value of the input param or its default value if it does not exist.
+ * Raises a NoSuchElementException if there is no value associated with the input param.
+ */
+ def apply[T](param: Param[T]): T = {
+ val value = get(param)
+ if (value.isDefined) {
+ value.get
+ } else {
+ throw new NoSuchElementException(s"Cannot find param ${param.name}.")
+ }
+ }
+
+ /**
+ * Checks whether a parameter is explicitly specified.
+ */
+ def contains(param: Param[_]): Boolean = {
+ map.contains(param.asInstanceOf[Param[Any]])
+ }
+
+ /**
+ * Filters this param map for the given parent.
+ */
+ def filter(parent: Params): ParamMap = {
+ val filtered = map.filterKeys(_.parent == parent)
+ new ParamMap(filtered.asInstanceOf[mutable.Map[Param[Any], Any]])
+ }
+
+ /**
+ * Make a copy of this param map.
+ */
+ def copy: ParamMap = new ParamMap(map.clone())
+
+ override def toString: String = {
+ map.map { case (param, value) =>
+ s"\t${param.parent.uid}-${param.name}: $value"
+ }.mkString("{\n", ",\n", "\n}")
+ }
+
+ /**
+ * Returns a new param map that contains parameters in this map and the given map,
+ * where the latter overwrites this if there exists conflicts.
+ */
+ def ++(other: ParamMap): ParamMap = {
+ new ParamMap(this.map ++ other.map)
+ }
+
+
+ /**
+ * Adds all parameters from the input param map into this param map.
+ */
+ def ++=(other: ParamMap): this.type = {
+ this.map ++= other.map
+ this
+ }
+
+ /**
+ * Converts this param map to a sequence of param pairs.
+ */
+ def toSeq: Seq[ParamPair[_]] = {
+ map.toSeq.map { case (param, value) =>
+ ParamPair(param, value)
+ }
+ }
+}
+
+object ParamMap {
+
+ /**
+ * Returns an empty param map.
+ */
+ def empty: ParamMap = new ParamMap()
+
+ /**
+ * Constructs a param map by specifying its entries.
+ */
+ @varargs
+ def apply(paramPairs: ParamPair[_]*): ParamMap = {
+ new ParamMap().put(paramPairs: _*)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
new file mode 100644
index 0000000000000..ef141d3eb2b06
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param
+
+private[ml] trait HasRegParam extends Params {
+ /** param for regularization parameter */
+ val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter")
+ def getRegParam: Double = get(regParam)
+}
+
+private[ml] trait HasMaxIter extends Params {
+ /** param for max number of iterations */
+ val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
+ def getMaxIter: Int = get(maxIter)
+}
+
+private[ml] trait HasFeaturesCol extends Params {
+ /** param for features column name */
+ val featuresCol: Param[String] =
+ new Param(this, "featuresCol", "features column name", Some("features"))
+ def getFeaturesCol: String = get(featuresCol)
+}
+
+private[ml] trait HasLabelCol extends Params {
+ /** param for label column name */
+ val labelCol: Param[String] = new Param(this, "labelCol", "label column name", Some("label"))
+ def getLabelCol: String = get(labelCol)
+}
+
+private[ml] trait HasScoreCol extends Params {
+ /** param for score column name */
+ val scoreCol: Param[String] = new Param(this, "scoreCol", "score column name", Some("score"))
+ def getScoreCol: String = get(scoreCol)
+}
+
+private[ml] trait HasPredictionCol extends Params {
+ /** param for prediction column name */
+ val predictionCol: Param[String] =
+ new Param(this, "predictionCol", "prediction column name", Some("prediction"))
+ def getPredictionCol: String = get(predictionCol)
+}
+
+private[ml] trait HasThreshold extends Params {
+ /** param for threshold in (binary) prediction */
+ val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction")
+ def getThreshold: Double = get(threshold)
+}
+
+private[ml] trait HasInputCol extends Params {
+ /** param for input column name */
+ val inputCol: Param[String] = new Param(this, "inputCol", "input column name")
+ def getInputCol: String = get(inputCol)
+}
+
+private[ml] trait HasOutputCol extends Params {
+ /** param for output column name */
+ val outputCol: Param[String] = new Param(this, "outputCol", "output column name")
+ def getOutputCol: String = get(outputCol)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
new file mode 100644
index 0000000000000..194b9bfd9a9e6
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import com.github.fommil.netlib.F2jBLAS
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml._
+import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.sql.{SchemaRDD, StructType}
+
+/**
+ * Params for [[CrossValidator]] and [[CrossValidatorModel]].
+ */
+private[ml] trait CrossValidatorParams extends Params {
+ /** param for the estimator to be cross-validated */
+ val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
+ def getEstimator: Estimator[_] = get(estimator)
+
+ /** param for estimator param maps */
+ val estimatorParamMaps: Param[Array[ParamMap]] =
+ new Param(this, "estimatorParamMaps", "param maps for the estimator")
+ def getEstimatorParamMaps: Array[ParamMap] = get(estimatorParamMaps)
+
+ /** param for the evaluator for selection */
+ val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection")
+ def getEvaluator: Evaluator = get(evaluator)
+
+ /** param for number of folds for cross validation */
+ val numFolds: IntParam =
+ new IntParam(this, "numFolds", "number of folds for cross validation", Some(3))
+ def getNumFolds: Int = get(numFolds)
+}
+
+/**
+ * :: AlphaComponent ::
+ * K-fold cross validation.
+ */
+@AlphaComponent
+class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorParams with Logging {
+
+ private val f2jBLAS = new F2jBLAS
+
+ def setEstimator(value: Estimator[_]): this.type = set(estimator, value)
+ def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value)
+ def setEvaluator(value: Evaluator): this.type = set(evaluator, value)
+ def setNumFolds(value: Int): this.type = set(numFolds, value)
+
+ override def fit(dataset: SchemaRDD, paramMap: ParamMap): CrossValidatorModel = {
+ val map = this.paramMap ++ paramMap
+ val schema = dataset.schema
+ transformSchema(dataset.schema, paramMap, logging = true)
+ val sqlCtx = dataset.sqlContext
+ val est = map(estimator)
+ val eval = map(evaluator)
+ val epm = map(estimatorParamMaps)
+ val numModels = epm.size
+ val metrics = new Array[Double](epm.size)
+ val splits = MLUtils.kFold(dataset, map(numFolds), 0)
+ splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
+ val trainingDataset = sqlCtx.applySchema(training, schema).cache()
+ val validationDataset = sqlCtx.applySchema(validation, schema).cache()
+ // multi-model training
+ logDebug(s"Train split $splitIndex with multiple sets of parameters.")
+ val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
+ var i = 0
+ while (i < numModels) {
+ val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)), map)
+ logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
+ metrics(i) += metric
+ i += 1
+ }
+ }
+ f2jBLAS.dscal(numModels, 1.0 / map(numFolds), metrics, 1)
+ logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
+ val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1)
+ logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
+ logInfo(s"Best cross-validation metric: $bestMetric.")
+ val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
+ val cvModel = new CrossValidatorModel(this, map, bestModel)
+ Params.inheritValues(map, this, cvModel)
+ cvModel
+ }
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ val map = this.paramMap ++ paramMap
+ map(estimator).transformSchema(schema, paramMap)
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * Model from k-fold cross validation.
+ */
+@AlphaComponent
+class CrossValidatorModel private[ml] (
+ override val parent: CrossValidator,
+ override val fittingParamMap: ParamMap,
+ val bestModel: Model[_])
+ extends Model[CrossValidatorModel] with CrossValidatorParams {
+
+ override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ bestModel.transform(dataset, paramMap)
+ }
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ bestModel.transformSchema(schema, paramMap)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala
new file mode 100644
index 0000000000000..dafe73d82c00a
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import scala.annotation.varargs
+import scala.collection.mutable
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param._
+
+/**
+ * :: AlphaComponent ::
+ * Builder for a param grid used in grid search-based model selection.
+ */
+@AlphaComponent
+class ParamGridBuilder {
+
+ private val paramGrid = mutable.Map.empty[Param[_], Iterable[_]]
+
+ /**
+ * Sets the given parameters in this grid to fixed values.
+ */
+ def baseOn(paramMap: ParamMap): this.type = {
+ baseOn(paramMap.toSeq: _*)
+ this
+ }
+
+ /**
+ * Sets the given parameters in this grid to fixed values.
+ */
+ @varargs
+ def baseOn(paramPairs: ParamPair[_]*): this.type = {
+ paramPairs.foreach { p =>
+ addGrid(p.param.asInstanceOf[Param[Any]], Seq(p.value))
+ }
+ this
+ }
+
+ /**
+ * Adds a param with multiple values (overwrites if the input param exists).
+ */
+ def addGrid[T](param: Param[T], values: Iterable[T]): this.type = {
+ paramGrid.put(param, values)
+ this
+ }
+
+ // specialized versions of addGrid for Java.
+
+ /**
+ * Adds a double param with multiple values.
+ */
+ def addGrid(param: DoubleParam, values: Array[Double]): this.type = {
+ addGrid[Double](param, values)
+ }
+
+ /**
+ * Adds a int param with multiple values.
+ */
+ def addGrid(param: IntParam, values: Array[Int]): this.type = {
+ addGrid[Int](param, values)
+ }
+
+ /**
+ * Adds a float param with multiple values.
+ */
+ def addGrid(param: FloatParam, values: Array[Float]): this.type = {
+ addGrid[Float](param, values)
+ }
+
+ /**
+ * Adds a long param with multiple values.
+ */
+ def addGrid(param: LongParam, values: Array[Long]): this.type = {
+ addGrid[Long](param, values)
+ }
+
+ /**
+ * Adds a boolean param with true and false.
+ */
+ def addGrid(param: BooleanParam): this.type = {
+ addGrid[Boolean](param, Array(true, false))
+ }
+
+ /**
+ * Builds and returns all combinations of parameters specified by the param grid.
+ */
+ def build(): Array[ParamMap] = {
+ var paramMaps = Array(new ParamMap)
+ paramGrid.foreach { case (param, values) =>
+ val newParamMaps = values.flatMap { v =>
+ paramMaps.map(_.copy.put(param.asInstanceOf[Param[Any]], v))
+ }
+ paramMaps = newParamMaps.toArray
+ }
+ paramMaps
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index e9f41758581e3..9f20cd5d00dcd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -18,6 +18,8 @@
package org.apache.spark.mllib.api.python
import java.io.OutputStream
+import java.nio.{ByteBuffer, ByteOrder}
+import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
import scala.collection.JavaConverters._
import scala.language.existentials
@@ -27,24 +29,27 @@ import net.razorvine.pickle._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
+import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
-import org.apache.spark.mllib.optimization._
+import org.apache.spark.mllib.feature._
import org.apache.spark.mllib.linalg._
+import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.random.{RandomRDDs => RG}
import org.apache.spark.mllib.recommendation._
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
-import org.apache.spark.mllib.tree.DecisionTree
-import org.apache.spark.mllib.tree.impurity._
-import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
import org.apache.spark.mllib.stat.correlation.CorrelationNames
+import org.apache.spark.mllib.stat.test.ChiSqTestResult
+import org.apache.spark.mllib.tree.{RandomForest, DecisionTree}
+import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
+import org.apache.spark.mllib.tree.impurity._
+import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
-
/**
* :: DeveloperApi ::
* The Java stubs necessary for the Python mllib bindings.
@@ -69,15 +74,29 @@ class PythonMLLibAPI extends Serializable {
private def trainRegressionModel(
learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel],
data: JavaRDD[LabeledPoint],
- initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = {
- val initialWeights = SerDe.loads(initialWeightsBA).asInstanceOf[Vector]
- // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
- learner.disableUncachedWarning()
- val model = learner.run(data.rdd, initialWeights)
- val ret = new java.util.LinkedList[java.lang.Object]()
- ret.add(SerDe.dumps(model.weights))
- ret.add(model.intercept: java.lang.Double)
- ret
+ initialWeights: Vector): JList[Object] = {
+ try {
+ val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), initialWeights)
+ List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
+ } finally {
+ data.rdd.unpersist(blocking = false)
+ }
+ }
+
+ /**
+ * Return the Updater from string
+ */
+ def getUpdaterFromString(regType: String): Updater = {
+ if (regType == "l2") {
+ new SquaredL2Updater
+ } else if (regType == "l1") {
+ new L1Updater
+ } else if (regType == null || regType == "none") {
+ new SimpleUpdater
+ } else {
+ throw new IllegalArgumentException("Invalid value for 'regType' parameter."
+ + " Can only be initialized using the following string values: ['l1', 'l2', None].")
+ }
}
/**
@@ -88,10 +107,10 @@ class PythonMLLibAPI extends Serializable {
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte],
+ initialWeights: Vector,
regParam: Double,
regType: String,
- intercept: Boolean): java.util.List[java.lang.Object] = {
+ intercept: Boolean): JList[Object] = {
val lrAlg = new LinearRegressionWithSGD()
lrAlg.setIntercept(intercept)
lrAlg.optimizer
@@ -99,18 +118,11 @@ class PythonMLLibAPI extends Serializable {
.setRegParam(regParam)
.setStepSize(stepSize)
.setMiniBatchFraction(miniBatchFraction)
- if (regType == "l2") {
- lrAlg.optimizer.setUpdater(new SquaredL2Updater)
- } else if (regType == "l1") {
- lrAlg.optimizer.setUpdater(new L1Updater)
- } else if (regType != "none") {
- throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
- + " Can only be initialized using the following string values: [l1, l2, none].")
- }
+ lrAlg.optimizer.setUpdater(getUpdaterFromString(regType))
trainRegressionModel(
lrAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -122,7 +134,7 @@ class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ initialWeights: Vector): JList[Object] = {
val lassoAlg = new LassoWithSGD()
lassoAlg.optimizer
.setNumIterations(numIterations)
@@ -132,7 +144,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
lassoAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -144,7 +156,7 @@ class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ initialWeights: Vector): JList[Object] = {
val ridgeAlg = new RidgeRegressionWithSGD()
ridgeAlg.optimizer
.setNumIterations(numIterations)
@@ -154,7 +166,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
ridgeAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -166,9 +178,9 @@ class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte],
+ initialWeights: Vector,
regType: String,
- intercept: Boolean): java.util.List[java.lang.Object] = {
+ intercept: Boolean): JList[Object] = {
val SVMAlg = new SVMWithSGD()
SVMAlg.setIntercept(intercept)
SVMAlg.optimizer
@@ -176,18 +188,11 @@ class PythonMLLibAPI extends Serializable {
.setRegParam(regParam)
.setStepSize(stepSize)
.setMiniBatchFraction(miniBatchFraction)
- if (regType == "l2") {
- SVMAlg.optimizer.setUpdater(new SquaredL2Updater)
- } else if (regType == "l1") {
- SVMAlg.optimizer.setUpdater(new L1Updater)
- } else if (regType != "none") {
- throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
- + " Can only be initialized using the following string values: [l1, l2, none].")
- }
+ SVMAlg.optimizer.setUpdater(getUpdaterFromString(regType))
trainRegressionModel(
SVMAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -198,10 +203,10 @@ class PythonMLLibAPI extends Serializable {
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte],
+ initialWeights: Vector,
regParam: Double,
regType: String,
- intercept: Boolean): java.util.List[java.lang.Object] = {
+ intercept: Boolean): JList[Object] = {
val LogRegAlg = new LogisticRegressionWithSGD()
LogRegAlg.setIntercept(intercept)
LogRegAlg.optimizer
@@ -209,18 +214,37 @@ class PythonMLLibAPI extends Serializable {
.setRegParam(regParam)
.setStepSize(stepSize)
.setMiniBatchFraction(miniBatchFraction)
- if (regType == "l2") {
- LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
- } else if (regType == "l1") {
- LogRegAlg.optimizer.setUpdater(new L1Updater)
- } else if (regType != "none") {
- throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
- + " Can only be initialized using the following string values: [l1, l2, none].")
- }
+ LogRegAlg.optimizer.setUpdater(getUpdaterFromString(regType))
+ trainRegressionModel(
+ LogRegAlg,
+ data,
+ initialWeights)
+ }
+
+ /**
+ * Java stub for Python mllib LogisticRegressionWithLBFGS.train()
+ */
+ def trainLogisticRegressionModelWithLBFGS(
+ data: JavaRDD[LabeledPoint],
+ numIterations: Int,
+ initialWeights: Vector,
+ regParam: Double,
+ regType: String,
+ intercept: Boolean,
+ corrections: Int,
+ tolerance: Double): JList[Object] = {
+ val LogRegAlg = new LogisticRegressionWithLBFGS()
+ LogRegAlg.setIntercept(intercept)
+ LogRegAlg.optimizer
+ .setNumIterations(numIterations)
+ .setRegParam(regParam)
+ .setNumCorrections(corrections)
+ .setConvergenceTol(tolerance)
+ LogRegAlg.optimizer.setUpdater(getUpdaterFromString(regType))
trainRegressionModel(
LogRegAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -228,13 +252,10 @@ class PythonMLLibAPI extends Serializable {
*/
def trainNaiveBayes(
data: JavaRDD[LabeledPoint],
- lambda: Double): java.util.List[java.lang.Object] = {
+ lambda: Double): JList[Object] = {
val model = NaiveBayes.train(data.rdd, lambda)
- val ret = new java.util.LinkedList[java.lang.Object]()
- ret.add(Vectors.dense(model.labels))
- ret.add(Vectors.dense(model.pi))
- ret.add(model.theta)
- ret
+ List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta).
+ map(_.asInstanceOf[Object]).asJava
}
/**
@@ -251,9 +272,26 @@ class PythonMLLibAPI extends Serializable {
.setMaxIterations(maxIterations)
.setRuns(runs)
.setInitializationMode(initializationMode)
- // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
- .disableUncachedWarning()
- return kMeansAlg.run(data.rdd)
+ try {
+ kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
+ } finally {
+ data.rdd.unpersist(blocking = false)
+ }
+ }
+
+ /**
+ * A Wrapper of MatrixFactorizationModel to provide helpfer method for Python
+ */
+ private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorizationModel)
+ extends MatrixFactorizationModel(model.rank, model.userFeatures, model.productFeatures) {
+
+ def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] =
+ predict(SerDe.asTupleRDD(userAndProducts.rdd))
+
+ def getUserFeatures = SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]])
+
+ def getProductFeatures = SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]])
+
}
/**
@@ -263,12 +301,25 @@ class PythonMLLibAPI extends Serializable {
* the Py4J documentation.
*/
def trainALSModel(
- ratings: JavaRDD[Rating],
+ ratingsJRDD: JavaRDD[Rating],
rank: Int,
iterations: Int,
lambda: Double,
- blocks: Int): MatrixFactorizationModel = {
- ALS.train(ratings.rdd, rank, iterations, lambda, blocks)
+ blocks: Int,
+ nonnegative: Boolean,
+ seed: java.lang.Long): MatrixFactorizationModel = {
+
+ val als = new ALS()
+ .setRank(rank)
+ .setIterations(iterations)
+ .setLambda(lambda)
+ .setBlocks(blocks)
+ .setNonnegative(nonnegative)
+
+ if (seed != null) als.setSeed(seed)
+
+ val model = als.run(ratingsJRDD.rdd)
+ new MatrixFactorizationModelWrapper(model)
}
/**
@@ -283,8 +334,121 @@ class PythonMLLibAPI extends Serializable {
iterations: Int,
lambda: Double,
blocks: Int,
- alpha: Double): MatrixFactorizationModel = {
- ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha)
+ alpha: Double,
+ nonnegative: Boolean,
+ seed: java.lang.Long): MatrixFactorizationModel = {
+
+ val als = new ALS()
+ .setImplicitPrefs(true)
+ .setRank(rank)
+ .setIterations(iterations)
+ .setLambda(lambda)
+ .setBlocks(blocks)
+ .setAlpha(alpha)
+ .setNonnegative(nonnegative)
+
+ if (seed != null) als.setSeed(seed)
+
+ val model = als.run(ratingsJRDD.rdd)
+ new MatrixFactorizationModelWrapper(model)
+ }
+
+ /**
+ * Java stub for Normalizer.transform()
+ */
+ def normalizeVector(p: Double, vector: Vector): Vector = {
+ new Normalizer(p).transform(vector)
+ }
+
+ /**
+ * Java stub for Normalizer.transform()
+ */
+ def normalizeVector(p: Double, rdd: JavaRDD[Vector]): JavaRDD[Vector] = {
+ new Normalizer(p).transform(rdd)
+ }
+
+ /**
+ * Java stub for IDF.fit(). This stub returns a
+ * handle to the Java object instead of the content of the Java object.
+ * Extra care needs to be taken in the Python code to ensure it gets freed on
+ * exit; see the Py4J documentation.
+ */
+ def fitStandardScaler(
+ withMean: Boolean,
+ withStd: Boolean,
+ data: JavaRDD[Vector]): StandardScalerModel = {
+ new StandardScaler(withMean, withStd).fit(data.rdd)
+ }
+
+ /**
+ * Java stub for IDF.fit(). This stub returns a
+ * handle to the Java object instead of the content of the Java object.
+ * Extra care needs to be taken in the Python code to ensure it gets freed on
+ * exit; see the Py4J documentation.
+ */
+ def fitIDF(minDocFreq: Int, dataset: JavaRDD[Vector]): IDFModel = {
+ new IDF(minDocFreq).fit(dataset)
+ }
+
+ /**
+ * Java stub for Python mllib Word2Vec fit(). This stub returns a
+ * handle to the Java object instead of the content of the Java object.
+ * Extra care needs to be taken in the Python code to ensure it gets freed on
+ * exit; see the Py4J documentation.
+ * @param dataJRDD input JavaRDD
+ * @param vectorSize size of vector
+ * @param learningRate initial learning rate
+ * @param numPartitions number of partitions
+ * @param numIterations number of iterations
+ * @param seed initial seed for random generator
+ * @return A handle to java Word2VecModelWrapper instance at python side
+ */
+ def trainWord2Vec(
+ dataJRDD: JavaRDD[java.util.ArrayList[String]],
+ vectorSize: Int,
+ learningRate: Double,
+ numPartitions: Int,
+ numIterations: Int,
+ seed: Long): Word2VecModelWrapper = {
+ val word2vec = new Word2Vec()
+ .setVectorSize(vectorSize)
+ .setLearningRate(learningRate)
+ .setNumPartitions(numPartitions)
+ .setNumIterations(numIterations)
+ .setSeed(seed)
+ try {
+ val model = word2vec.fit(dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER))
+ new Word2VecModelWrapper(model)
+ } finally {
+ dataJRDD.rdd.unpersist(blocking = false)
+ }
+ }
+
+ private[python] class Word2VecModelWrapper(model: Word2VecModel) {
+ def transform(word: String): Vector = {
+ model.transform(word)
+ }
+
+ /**
+ * Transforms an RDD of words to its vector representation
+ * @param rdd an RDD of words
+ * @return an RDD of vector representations of words
+ */
+ def transform(rdd: JavaRDD[String]): JavaRDD[Vector] = {
+ rdd.rdd.map(model.transform)
+ }
+
+ def findSynonyms(word: String, num: Int): JList[Object] = {
+ val vec = transform(word)
+ findSynonyms(vec, num)
+ }
+
+ def findSynonyms(vector: Vector, num: Int): JList[Object] = {
+ val result = model.findSynonyms(vector, num)
+ val similarity = Vectors.dense(result.map(_._2))
+ val words = result.map(_._1)
+ List(words, similarity).map(_.asInstanceOf[Object]).asJava
+ }
}
/**
@@ -293,13 +457,13 @@ class PythonMLLibAPI extends Serializable {
* Extra care needs to be taken in the Python code to ensure it gets freed on exit;
* see the Py4J documentation.
* @param data Training data
- * @param categoricalFeaturesInfoJMap Categorical features info, as Java map
+ * @param categoricalFeaturesInfo Categorical features info, as Java map
*/
def trainDecisionTreeModel(
data: JavaRDD[LabeledPoint],
algoStr: String,
numClasses: Int,
- categoricalFeaturesInfoJMap: java.util.Map[Int, Int],
+ categoricalFeaturesInfo: JMap[Int, Int],
impurityStr: String,
maxDepth: Int,
maxBins: Int,
@@ -315,11 +479,53 @@ class PythonMLLibAPI extends Serializable {
maxDepth = maxDepth,
numClassesForClassification = numClasses,
maxBins = maxBins,
- categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap,
+ categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap,
minInstancesPerNode = minInstancesPerNode,
minInfoGain = minInfoGain)
+ try {
+ DecisionTree.train(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), strategy)
+ } finally {
+ data.rdd.unpersist(blocking = false)
+ }
+ }
- DecisionTree.train(data.rdd, strategy)
+ /**
+ * Java stub for Python mllib RandomForest.train().
+ * This stub returns a handle to the Java object instead of the content of the Java object.
+ * Extra care needs to be taken in the Python code to ensure it gets freed on exit;
+ * see the Py4J documentation.
+ */
+ def trainRandomForestModel(
+ data: JavaRDD[LabeledPoint],
+ algoStr: String,
+ numClasses: Int,
+ categoricalFeaturesInfo: JMap[Int, Int],
+ numTrees: Int,
+ featureSubsetStrategy: String,
+ impurityStr: String,
+ maxDepth: Int,
+ maxBins: Int,
+ seed: Int): RandomForestModel = {
+
+ val algo = Algo.fromString(algoStr)
+ val impurity = Impurities.fromString(impurityStr)
+ val strategy = new Strategy(
+ algo = algo,
+ impurity = impurity,
+ maxDepth = maxDepth,
+ numClassesForClassification = numClasses,
+ maxBins = maxBins,
+ categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap)
+ val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK)
+ try {
+ if (algo == Algo.Classification) {
+ RandomForest.trainClassifier(cached, strategy, numTrees, featureSubsetStrategy, seed)
+ } else {
+ RandomForest.trainRegressor(cached, strategy, numTrees, featureSubsetStrategy, seed)
+ }
+ } finally {
+ cached.unpersist(blocking = false)
+ }
}
/**
@@ -346,6 +552,31 @@ class PythonMLLibAPI extends Serializable {
Statistics.corr(x.rdd, y.rdd, getCorrNameOrDefault(method))
}
+ /**
+ * Java stub for mllib Statistics.chiSqTest()
+ */
+ def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = {
+ if (expected == null) {
+ Statistics.chiSqTest(observed)
+ } else {
+ Statistics.chiSqTest(observed, expected)
+ }
+ }
+
+ /**
+ * Java stub for mllib Statistics.chiSqTest(observed: Matrix)
+ */
+ def chiSqTest(observed: Matrix): ChiSqTestResult = {
+ Statistics.chiSqTest(observed)
+ }
+
+ /**
+ * Java stub for mllib Statistics.chiSqTest(RDD[LabelPoint])
+ */
+ def chiSqTest(data: JavaRDD[LabeledPoint]): Array[ChiSqTestResult] = {
+ Statistics.chiSqTest(data.rdd)
+ }
+
// used by the corr methods to retrieve the name of the correlation method passed in via pyspark
private def getCorrNameOrDefault(method: String) = {
if (method == null) CorrelationNames.defaultCorrName else method
@@ -454,6 +685,7 @@ class PythonMLLibAPI extends Serializable {
private[spark] object SerDe extends Serializable {
val PYSPARK_PACKAGE = "pyspark.mllib"
+ val LATIN1 = "ISO-8859-1"
/**
* Base class used for pickle
@@ -475,7 +707,7 @@ private[spark] object SerDe extends Serializable {
def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
if (obj == this) {
out.write(Opcodes.GLOBAL)
- out.write((module + "\n" + name + "\n").getBytes())
+ out.write((module + "\n" + name + "\n").getBytes)
} else {
pickler.save(this) // it will be memorized by Pickler
saveState(obj, out, pickler)
@@ -487,7 +719,7 @@ private[spark] object SerDe extends Serializable {
if (objects.length == 0 || objects.length > 3) {
out.write(Opcodes.MARK)
}
- objects.foreach(pickler.save(_))
+ objects.foreach(pickler.save)
val code = objects.length match {
case 1 => Opcodes.TUPLE1
case 2 => Opcodes.TUPLE2
@@ -505,7 +737,16 @@ private[spark] object SerDe extends Serializable {
def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
val vector: DenseVector = obj.asInstanceOf[DenseVector]
- saveObjects(out, pickler, vector.toArray)
+ val bytes = new Array[Byte](8 * vector.size)
+ val bb = ByteBuffer.wrap(bytes)
+ bb.order(ByteOrder.nativeOrder())
+ val db = bb.asDoubleBuffer()
+ db.put(vector.values)
+
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(bytes.length))
+ out.write(bytes)
+ out.write(Opcodes.TUPLE1)
}
def construct(args: Array[Object]): Object = {
@@ -513,7 +754,13 @@ private[spark] object SerDe extends Serializable {
if (args.length != 1) {
throw new PickleException("should be 1")
}
- new DenseVector(args(0).asInstanceOf[Array[Double]])
+ val bytes = args(0).asInstanceOf[String].getBytes(LATIN1)
+ val bb = ByteBuffer.wrap(bytes, 0, bytes.length)
+ bb.order(ByteOrder.nativeOrder())
+ val db = bb.asDoubleBuffer()
+ val ans = new Array[Double](bytes.length / 8)
+ db.get(ans)
+ Vectors.dense(ans)
}
}
@@ -522,15 +769,30 @@ private[spark] object SerDe extends Serializable {
def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
val m: DenseMatrix = obj.asInstanceOf[DenseMatrix]
- saveObjects(out, pickler, m.numRows, m.numCols, m.values)
+ val bytes = new Array[Byte](8 * m.values.size)
+ val order = ByteOrder.nativeOrder()
+ ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().put(m.values)
+
+ out.write(Opcodes.BININT)
+ out.write(PickleUtils.integer_to_bytes(m.numRows))
+ out.write(Opcodes.BININT)
+ out.write(PickleUtils.integer_to_bytes(m.numCols))
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(bytes.length))
+ out.write(bytes)
+ out.write(Opcodes.TUPLE3)
}
def construct(args: Array[Object]): Object = {
if (args.length != 3) {
throw new PickleException("should be 3")
}
- new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int],
- args(2).asInstanceOf[Array[Double]])
+ val bytes = args(2).asInstanceOf[String].getBytes(LATIN1)
+ val n = bytes.length / 8
+ val values = new Array[Double](n)
+ val order = ByteOrder.nativeOrder()
+ ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().get(values)
+ new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values)
}
}
@@ -539,15 +801,40 @@ private[spark] object SerDe extends Serializable {
def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
val v: SparseVector = obj.asInstanceOf[SparseVector]
- saveObjects(out, pickler, v.size, v.indices, v.values)
+ val n = v.indices.size
+ val indiceBytes = new Array[Byte](4 * n)
+ val order = ByteOrder.nativeOrder()
+ ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().put(v.indices)
+ val valueBytes = new Array[Byte](8 * n)
+ ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().put(v.values)
+
+ out.write(Opcodes.BININT)
+ out.write(PickleUtils.integer_to_bytes(v.size))
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(indiceBytes.length))
+ out.write(indiceBytes)
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(valueBytes.length))
+ out.write(valueBytes)
+ out.write(Opcodes.TUPLE3)
}
def construct(args: Array[Object]): Object = {
if (args.length != 3) {
throw new PickleException("should be 3")
}
- new SparseVector(args(0).asInstanceOf[Int], args(1).asInstanceOf[Array[Int]],
- args(2).asInstanceOf[Array[Double]])
+ val size = args(0).asInstanceOf[Int]
+ val indiceBytes = args(1).asInstanceOf[String].getBytes(LATIN1)
+ val valueBytes = args(2).asInstanceOf[String].getBytes(LATIN1)
+ val n = indiceBytes.length / 4
+ val indices = new Array[Int](n)
+ val values = new Array[Double](n)
+ if (n > 0) {
+ val order = ByteOrder.nativeOrder()
+ ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().get(indices)
+ ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().get(values)
+ }
+ new SparseVector(size, indices, values)
}
}
@@ -584,13 +871,24 @@ private[spark] object SerDe extends Serializable {
}
}
+ var initialized = false
+ // This should be called before trying to serialize any above classes
+ // In cluster mode, this should be put in the closure
def initialize(): Unit = {
- new DenseVectorPickler().register()
- new DenseMatrixPickler().register()
- new SparseVectorPickler().register()
- new LabeledPointPickler().register()
- new RatingPickler().register()
+ SerDeUtil.initialize()
+ synchronized {
+ if (!initialized) {
+ new DenseVectorPickler().register()
+ new DenseMatrixPickler().register()
+ new SparseVectorPickler().register()
+ new LabeledPointPickler().register()
+ new RatingPickler().register()
+ initialized = true
+ }
+ }
}
+ // will not called in Executor automatically
+ initialize()
def dumps(obj: AnyRef): Array[Byte] = {
new Pickler().dumps(obj)
@@ -604,4 +902,38 @@ private[spark] object SerDe extends Serializable {
def asTupleRDD(rdd: RDD[Array[Any]]): RDD[(Int, Int)] = {
rdd.map(x => (x(0).asInstanceOf[Int], x(1).asInstanceOf[Int]))
}
+
+ /* convert RDD[Tuple2[,]] to RDD[Array[Any]] */
+ def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = {
+ rdd.map(x => Array(x._1, x._2))
+ }
+
+ /**
+ * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
+ * PySpark.
+ */
+ def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
+ jRDD.rdd.mapPartitions { iter =>
+ initialize() // let it called in executor
+ new SerDeUtil.AutoBatchedPickler(iter)
+ }
+ }
+
+ /**
+ * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark.
+ */
+ def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
+ pyRDD.rdd.mapPartitions { iter =>
+ initialize() // let it called in executor
+ val unpickle = new Unpickler
+ iter.flatMap { row =>
+ val obj = unpickle.loads(row)
+ if (batched) {
+ obj.asInstanceOf[JArrayList[_]].asScala
+ } else {
+ Seq(obj)
+ }
+ }
+ }.toJavaRDD()
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/package.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/package.scala
index 87bdc8558aaf5..c67a6d3ae6cce 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/package.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/package.scala
@@ -18,7 +18,7 @@
package org.apache.spark.mllib.api
/**
- * Internal support for MLLib Python API.
+ * Internal support for MLlib Python API.
*
* @see [[org.apache.spark.mllib.api.python.PythonMLLibAPI]]
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 84d3c7cebd7c8..94d757bc317ab 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -64,16 +64,17 @@ class LogisticRegressionModel (
val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
val score = 1.0 / (1.0 + math.exp(-margin))
threshold match {
- case Some(t) => if (score < t) 0.0 else 1.0
+ case Some(t) => if (score > t) 1.0 else 0.0
case None => score
}
}
}
/**
- * Train a classification model for Logistic Regression using Stochastic Gradient Descent.
- * NOTE: Labels used in Logistic Regression should be {0, 1}
- *
+ * Train a classification model for Logistic Regression using Stochastic Gradient Descent. By
+ * default L2 regularization is used, which can be changed via
+ * [[LogisticRegressionWithSGD.optimizer]].
+ * NOTE: Labels used in Logistic Regression should be {0, 1}.
* Using [[LogisticRegressionWithLBFGS]] is recommended over this.
*/
class LogisticRegressionWithSGD private (
@@ -93,9 +94,10 @@ class LogisticRegressionWithSGD private (
override protected val validators = List(DataValidators.binaryLabelValidator)
/**
- * Construct a LogisticRegression object with default parameters
+ * Construct a LogisticRegression object with default parameters: {stepSize: 1.0,
+ * numIterations: 100, regParm: 0.01, miniBatchFraction: 1.0}.
*/
- def this() = this(1.0, 100, 0.0, 1.0)
+ def this() = this(1.0, 100, 0.01, 1.0)
override protected def createModel(weights: Vector, intercept: Double) = {
new LogisticRegressionModel(weights, intercept)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index 80f8a1b2f1e84..dd514ff8a37f2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -65,14 +65,15 @@ class SVMModel (
intercept: Double) = {
val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
threshold match {
- case Some(t) => if (margin < t) 0.0 else 1.0
+ case Some(t) => if (margin > t) 1.0 else 0.0
case None => margin
}
}
}
/**
- * Train a Support Vector Machine (SVM) using Stochastic Gradient Descent.
+ * Train a Support Vector Machine (SVM) using Stochastic Gradient Descent. By default L2
+ * regularization is used, which can be changed via [[SVMWithSGD.optimizer]].
* NOTE: Labels used in SVM should be {0, 1}.
*/
class SVMWithSGD private (
@@ -92,9 +93,10 @@ class SVMWithSGD private (
override protected val validators = List(DataValidators.binaryLabelValidator)
/**
- * Construct a SVM object with default parameters
+ * Construct a SVM object with default parameters: {stepSize: 1.0, numIterations: 100,
+ * regParm: 0.01, miniBatchFraction: 1.0}.
*/
- def this() = this(1.0, 100, 1.0, 1.0)
+ def this() = this(1.0, 100, 0.01, 1.0)
override protected def createModel(weights: Vector, intercept: Double) = {
new SVMModel(weights, intercept)
@@ -185,6 +187,6 @@ object SVMWithSGD {
* @return a SVMModel which has the weights and offset from training.
*/
def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = {
- train(input, numIterations, 1.0, 1.0, 1.0)
+ train(input, numIterations, 1.0, 0.01, 1.0)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 7443f232ec3e7..34ea0de706f08 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -113,22 +113,13 @@ class KMeans private (
this
}
- /** Whether a warning should be logged if the input RDD is uncached. */
- private var warnOnUncachedInput = true
-
- /** Disable warnings about uncached input. */
- private[spark] def disableUncachedWarning(): this.type = {
- warnOnUncachedInput = false
- this
- }
-
/**
* Train a K-means model on the given set of points; `data` should be cached for high
* performance, because this is an iterative algorithm.
*/
def run(data: RDD[Vector]): KMeansModel = {
- if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) {
+ if (data.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data is not directly cached, which may hurt performance if its"
+ " parent RDDs are also uncached.")
}
@@ -143,7 +134,7 @@ class KMeans private (
norms.unpersist()
// Warn at the end of the run as well, for increased visibility.
- if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) {
+ if (data.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data was not directly cached, which may hurt performance if its"
+ " parent RDDs are also uncached.")
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
new file mode 100644
index 0000000000000..6189dce9b27da
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.Logging
+import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.StreamingContext._
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
+
+/**
+ * :: DeveloperApi ::
+ * StreamingKMeansModel extends MLlib's KMeansModel for streaming
+ * algorithms, so it can keep track of a continuously updated weight
+ * associated with each cluster, and also update the model by
+ * doing a single iteration of the standard k-means algorithm.
+ *
+ * The update algorithm uses the "mini-batch" KMeans rule,
+ * generalized to incorporate forgetfullness (i.e. decay).
+ * The update rule (for each cluster) is:
+ *
+ * c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t]
+ * n_t+t = n_t * a + m_t
+ *
+ * Where c_t is the previously estimated centroid for that cluster,
+ * n_t is the number of points assigned to it thus far, x_t is the centroid
+ * estimated on the current batch, and m_t is the number of points assigned
+ * to that centroid in the current batch.
+ *
+ * The decay factor 'a' scales the contribution of the clusters as estimated thus far,
+ * by applying a as a discount weighting on the current point when evaluating
+ * new incoming data. If a=1, all batches are weighted equally. If a=0, new centroids
+ * are determined entirely by recent data. Lower values correspond to
+ * more forgetting.
+ *
+ * Decay can optionally be specified by a half life and associated
+ * time unit. The time unit can either be a batch of data or a single
+ * data point. Considering data arrived at time t, the half life h is defined
+ * such that at time t + h the discount applied to the data from t is 0.5.
+ * The definition remains the same whether the time unit is given
+ * as batches or points.
+ *
+ */
+@DeveloperApi
+class StreamingKMeansModel(
+ override val clusterCenters: Array[Vector],
+ val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging {
+
+ /** Perform a k-means update on a batch of data. */
+ def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = {
+
+ // find nearest cluster to each point
+ val closest = data.map(point => (this.predict(point), (point, 1L)))
+
+ // get sums and counts for updating each cluster
+ val mergeContribs: ((Vector, Long), (Vector, Long)) => (Vector, Long) = (p1, p2) => {
+ BLAS.axpy(1.0, p2._1, p1._1)
+ (p1._1, p1._2 + p2._2)
+ }
+ val dim = clusterCenters(0).size
+ val pointStats: Array[(Int, (Vector, Long))] = closest
+ .aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs)
+ .collect()
+
+ val discount = timeUnit match {
+ case StreamingKMeans.BATCHES => decayFactor
+ case StreamingKMeans.POINTS =>
+ val numNewPoints = pointStats.view.map { case (_, (_, n)) =>
+ n
+ }.sum
+ math.pow(decayFactor, numNewPoints)
+ }
+
+ // apply discount to weights
+ BLAS.scal(discount, Vectors.dense(clusterWeights))
+
+ // implement update rule
+ pointStats.foreach { case (label, (sum, count)) =>
+ val centroid = clusterCenters(label)
+
+ val updatedWeight = clusterWeights(label) + count
+ val lambda = count / math.max(updatedWeight, 1e-16)
+
+ clusterWeights(label) = updatedWeight
+ BLAS.scal(1.0 - lambda, centroid)
+ BLAS.axpy(lambda / count, sum, centroid)
+
+ // display the updated cluster centers
+ val display = clusterCenters(label).size match {
+ case x if x > 100 => centroid.toArray.take(100).mkString("[", ",", "...")
+ case _ => centroid.toArray.mkString("[", ",", "]")
+ }
+
+ logInfo(s"Cluster $label updated with weight $updatedWeight and centroid: $display")
+ }
+
+ // Check whether the smallest cluster is dying. If so, split the largest cluster.
+ val weightsWithIndex = clusterWeights.view.zipWithIndex
+ val (maxWeight, largest) = weightsWithIndex.maxBy(_._1)
+ val (minWeight, smallest) = weightsWithIndex.minBy(_._1)
+ if (minWeight < 1e-8 * maxWeight) {
+ logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.")
+ val weight = (maxWeight + minWeight) / 2.0
+ clusterWeights(largest) = weight
+ clusterWeights(smallest) = weight
+ val largestClusterCenter = clusterCenters(largest)
+ val smallestClusterCenter = clusterCenters(smallest)
+ var j = 0
+ while (j < dim) {
+ val x = largestClusterCenter(j)
+ val p = 1e-14 * math.max(math.abs(x), 1.0)
+ largestClusterCenter.toBreeze(j) = x + p
+ smallestClusterCenter.toBreeze(j) = x - p
+ j += 1
+ }
+ }
+
+ this
+ }
+}
+
+/**
+ * :: DeveloperApi ::
+ * StreamingKMeans provides methods for configuring a
+ * streaming k-means analysis, training the model on streaming,
+ * and using the model to make predictions on streaming data.
+ * See KMeansModel for details on algorithm and update rules.
+ *
+ * Use a builder pattern to construct a streaming k-means analysis
+ * in an application, like:
+ *
+ * val model = new StreamingKMeans()
+ * .setDecayFactor(0.5)
+ * .setK(3)
+ * .setRandomCenters(5, 100.0)
+ * .trainOn(DStream)
+ */
+@DeveloperApi
+class StreamingKMeans(
+ var k: Int,
+ var decayFactor: Double,
+ var timeUnit: String) extends Logging {
+
+ def this() = this(2, 1.0, StreamingKMeans.BATCHES)
+
+ protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)
+
+ /** Set the number of clusters. */
+ def setK(k: Int): this.type = {
+ this.k = k
+ this
+ }
+
+ /** Set the decay factor directly (for forgetful algorithms). */
+ def setDecayFactor(a: Double): this.type = {
+ this.decayFactor = decayFactor
+ this
+ }
+
+ /** Set the half life and time unit ("batches" or "points") for forgetful algorithms. */
+ def setHalfLife(halfLife: Double, timeUnit: String): this.type = {
+ if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) {
+ throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit)
+ }
+ this.decayFactor = math.exp(math.log(0.5) / halfLife)
+ logInfo("Setting decay factor to: %g ".format (this.decayFactor))
+ this.timeUnit = timeUnit
+ this
+ }
+
+ /** Specify initial centers directly. */
+ def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = {
+ model = new StreamingKMeansModel(centers, weights)
+ this
+ }
+
+ /**
+ * Initialize random centers, requiring only the number of dimensions.
+ *
+ * @param dim Number of dimensions
+ * @param weight Weight for each center
+ * @param seed Random seed
+ */
+ def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = {
+ val random = new XORShiftRandom(seed)
+ val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian())))
+ val weights = Array.fill(k)(weight)
+ model = new StreamingKMeansModel(centers, weights)
+ this
+ }
+
+ /** Return the latest model. */
+ def latestModel(): StreamingKMeansModel = {
+ model
+ }
+
+ /**
+ * Update the clustering model by training on batches of data from a DStream.
+ * This operation registers a DStream for training the model,
+ * checks whether the cluster centers have been initialized,
+ * and updates the model using each batch of data from the stream.
+ *
+ * @param data DStream containing vector data
+ */
+ def trainOn(data: DStream[Vector]) {
+ assertInitialized()
+ data.foreachRDD { (rdd, time) =>
+ model = model.update(rdd, decayFactor, timeUnit)
+ }
+ }
+
+ /**
+ * Use the clustering model to make predictions on batches of data from a DStream.
+ *
+ * @param data DStream containing vector data
+ * @return DStream containing predictions
+ */
+ def predictOn(data: DStream[Vector]): DStream[Int] = {
+ assertInitialized()
+ data.map(model.predict)
+ }
+
+ /**
+ * Use the model to make predictions on the values of a DStream and carry over its keys.
+ *
+ * @param data DStream containing (key, feature vector) pairs
+ * @tparam K key type
+ * @return DStream containing the input keys and the predictions as values
+ */
+ def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = {
+ assertInitialized()
+ data.mapValues(model.predict)
+ }
+
+ /** Check whether cluster centers have been initialized. */
+ private[this] def assertInitialized(): Unit = {
+ if (model.clusterCenters == null) {
+ throw new IllegalStateException(
+ "Initial cluster centers must be set before starting predictions")
+ }
+ }
+}
+
+private[clustering] object StreamingKMeans {
+ final val BATCHES = "batches"
+ final val POINTS = "points"
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala
index 7858ec602483f..078fbfbe4f0e1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala
@@ -43,7 +43,7 @@ private[evaluation] object AreaUnderCurve {
*/
def of(curve: RDD[(Double, Double)]): Double = {
curve.sliding(2).aggregate(0.0)(
- seqOp = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points),
+ seqOp = (auc: Double, points: Array[(Double, Double)]) => auc + trapezoid(points),
combOp = _ + _
)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
new file mode 100644
index 0000000000000..ea10bde5fa252
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext._
+
+/**
+ * Evaluator for multilabel classification.
+ * @param predictionAndLabels an RDD of (predictions, labels) pairs,
+ * both are non-null Arrays, each with unique elements.
+ */
+class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) {
+
+ private lazy val numDocs: Long = predictionAndLabels.count()
+
+ private lazy val numLabels: Long = predictionAndLabels.flatMap { case (_, labels) =>
+ labels}.distinct().count()
+
+ /**
+ * Returns subset accuracy
+ * (for equal sets of labels)
+ */
+ lazy val subsetAccuracy: Double = predictionAndLabels.filter { case (predictions, labels) =>
+ predictions.deep == labels.deep
+ }.count().toDouble / numDocs
+
+ /**
+ * Returns accuracy
+ */
+ lazy val accuracy: Double = predictionAndLabels.map { case (predictions, labels) =>
+ labels.intersect(predictions).size.toDouble /
+ (labels.size + predictions.size - labels.intersect(predictions).size)}.sum / numDocs
+
+
+ /**
+ * Returns Hamming-loss
+ */
+ lazy val hammingLoss: Double = predictionAndLabels.map { case (predictions, labels) =>
+ labels.size + predictions.size - 2 * labels.intersect(predictions).size
+ }.sum / (numDocs * numLabels)
+
+ /**
+ * Returns document-based precision averaged by the number of documents
+ */
+ lazy val precision: Double = predictionAndLabels.map { case (predictions, labels) =>
+ if (predictions.size > 0) {
+ predictions.intersect(labels).size.toDouble / predictions.size
+ } else {
+ 0
+ }
+ }.sum / numDocs
+
+ /**
+ * Returns document-based recall averaged by the number of documents
+ */
+ lazy val recall: Double = predictionAndLabels.map { case (predictions, labels) =>
+ labels.intersect(predictions).size.toDouble / labels.size
+ }.sum / numDocs
+
+ /**
+ * Returns document-based f1-measure averaged by the number of documents
+ */
+ lazy val f1Measure: Double = predictionAndLabels.map { case (predictions, labels) =>
+ 2.0 * predictions.intersect(labels).size / (predictions.size + labels.size)
+ }.sum / numDocs
+
+ private lazy val tpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
+ predictions.intersect(labels)
+ }.countByValue()
+
+ private lazy val fpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
+ predictions.diff(labels)
+ }.countByValue()
+
+ private lazy val fnPerClass = predictionAndLabels.flatMap { case(predictions, labels) =>
+ labels.diff(predictions)
+ }.countByValue()
+
+ /**
+ * Returns precision for a given label (category)
+ * @param label the label.
+ */
+ def precision(label: Double) = {
+ val tp = tpPerClass(label)
+ val fp = fpPerClass.getOrElse(label, 0L)
+ if (tp + fp == 0) 0 else tp.toDouble / (tp + fp)
+ }
+
+ /**
+ * Returns recall for a given label (category)
+ * @param label the label.
+ */
+ def recall(label: Double) = {
+ val tp = tpPerClass(label)
+ val fn = fnPerClass.getOrElse(label, 0L)
+ if (tp + fn == 0) 0 else tp.toDouble / (tp + fn)
+ }
+
+ /**
+ * Returns f1-measure for a given label (category)
+ * @param label the label.
+ */
+ def f1Measure(label: Double) = {
+ val p = precision(label)
+ val r = recall(label)
+ if((p + r) == 0) 0 else 2 * p * r / (p + r)
+ }
+
+ private lazy val sumTp = tpPerClass.foldLeft(0L) { case (sum, (_, tp)) => sum + tp }
+ private lazy val sumFpClass = fpPerClass.foldLeft(0L) { case (sum, (_, fp)) => sum + fp }
+ private lazy val sumFnClass = fnPerClass.foldLeft(0L) { case (sum, (_, fn)) => sum + fn }
+
+ /**
+ * Returns micro-averaged label-based precision
+ * (equals to micro-averaged document-based precision)
+ */
+ lazy val microPrecision = {
+ val sumFp = fpPerClass.foldLeft(0L){ case(cum, (_, fp)) => cum + fp}
+ sumTp.toDouble / (sumTp + sumFp)
+ }
+
+ /**
+ * Returns micro-averaged label-based recall
+ * (equals to micro-averaged document-based recall)
+ */
+ lazy val microRecall = {
+ val sumFn = fnPerClass.foldLeft(0.0){ case(cum, (_, fn)) => cum + fn}
+ sumTp.toDouble / (sumTp + sumFn)
+ }
+
+ /**
+ * Returns micro-averaged label-based f1-measure
+ * (equals to micro-averaged document-based f1-measure)
+ */
+ lazy val microF1Measure = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass)
+
+ /**
+ * Returns the sequence of labels in ascending order
+ */
+ lazy val labels: Array[Double] = tpPerClass.keys.toArray.sorted
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
new file mode 100644
index 0000000000000..93a7353e2c070
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.Logging
+import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.rdd.RDD
+
+/**
+ * ::Experimental::
+ * Evaluator for ranking algorithms.
+ *
+ * @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs.
+ */
+@Experimental
+class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])])
+ extends Logging with Serializable {
+
+ /**
+ * Compute the average precision of all the queries, truncated at ranking position k.
+ *
+ * If for a query, the ranking algorithm returns n (n < k) results, the precision value will be
+ * computed as #(relevant items retrieved) / k. This formula also applies when the size of the
+ * ground truth set is less than k.
+ *
+ * If a query has an empty ground truth set, zero will be used as precision together with
+ * a log warning.
+ *
+ * See the following paper for detail:
+ *
+ * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
+ *
+ * @param k the position to compute the truncated precision, must be positive
+ * @return the average precision at the first k ranking positions
+ */
+ def precisionAt(k: Int): Double = {
+ require(k > 0, "ranking position k should be positive")
+ predictionAndLabels.map { case (pred, lab) =>
+ val labSet = lab.toSet
+
+ if (labSet.nonEmpty) {
+ val n = math.min(pred.length, k)
+ var i = 0
+ var cnt = 0
+ while (i < n) {
+ if (labSet.contains(pred(i))) {
+ cnt += 1
+ }
+ i += 1
+ }
+ cnt.toDouble / k
+ } else {
+ logWarning("Empty ground truth set, check input data")
+ 0.0
+ }
+ }.mean
+ }
+
+ /**
+ * Returns the mean average precision (MAP) of all the queries.
+ * If a query has an empty ground truth set, the average precision will be zero and a log
+ * warining is generated.
+ */
+ lazy val meanAveragePrecision: Double = {
+ predictionAndLabels.map { case (pred, lab) =>
+ val labSet = lab.toSet
+
+ if (labSet.nonEmpty) {
+ var i = 0
+ var cnt = 0
+ var precSum = 0.0
+ val n = pred.length
+ while (i < n) {
+ if (labSet.contains(pred(i))) {
+ cnt += 1
+ precSum += cnt.toDouble / (i + 1)
+ }
+ i += 1
+ }
+ precSum / labSet.size
+ } else {
+ logWarning("Empty ground truth set, check input data")
+ 0.0
+ }
+ }.mean
+ }
+
+ /**
+ * Compute the average NDCG value of all the queries, truncated at ranking position k.
+ * The discounted cumulative gain at position k is computed as:
+ * sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1),
+ * and the NDCG is obtained by dividing the DCG value on the ground truth set. In the current
+ * implementation, the relevance value is binary.
+
+ * If a query has an empty ground truth set, zero will be used as ndcg together with
+ * a log warning.
+ *
+ * See the following paper for detail:
+ *
+ * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
+ *
+ * @param k the position to compute the truncated ndcg, must be positive
+ * @return the average ndcg at the first k ranking positions
+ */
+ def ndcgAt(k: Int): Double = {
+ require(k > 0, "ranking position k should be positive")
+ predictionAndLabels.map { case (pred, lab) =>
+ val labSet = lab.toSet
+
+ if (labSet.nonEmpty) {
+ val labSetSize = labSet.size
+ val n = math.min(math.max(pred.length, labSetSize), k)
+ var maxDcg = 0.0
+ var dcg = 0.0
+ var i = 0
+ while (i < n) {
+ val gain = 1.0 / math.log(i + 2)
+ if (labSet.contains(pred(i))) {
+ dcg += gain
+ }
+ if (i < labSetSize) {
+ maxDcg += gain
+ }
+ i += 1
+ }
+ dcg / maxDcg
+ } else {
+ logWarning("Empty ground truth set, check input data")
+ 0.0
+ }
+ }.mean
+ }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
new file mode 100644
index 0000000000000..693117d820580
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.rdd.RDD
+import org.apache.spark.Logging
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer}
+
+/**
+ * :: Experimental ::
+ * Evaluator for regression.
+ *
+ * @param predictionAndObservations an RDD of (prediction, observation) pairs.
+ */
+@Experimental
+class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extends Logging {
+
+ /**
+ * Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors.
+ */
+ private lazy val summary: MultivariateStatisticalSummary = {
+ val summary: MultivariateStatisticalSummary = predictionAndObservations.map {
+ case (prediction, observation) => Vectors.dense(observation, observation - prediction)
+ }.aggregate(new MultivariateOnlineSummarizer())(
+ (summary, v) => summary.add(v),
+ (sum1, sum2) => sum1.merge(sum2)
+ )
+ summary
+ }
+
+ /**
+ * Returns the explained variance regression score.
+ * explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
+ * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]]
+ */
+ def explainedVariance: Double = {
+ 1 - summary.variance(1) / summary.variance(0)
+ }
+
+ /**
+ * Returns the mean absolute error, which is a risk function corresponding to the
+ * expected value of the absolute error loss or l1-norm loss.
+ */
+ def meanAbsoluteError: Double = {
+ summary.normL1(1) / summary.count
+ }
+
+ /**
+ * Returns the mean squared error, which is a risk function corresponding to the
+ * expected value of the squared error loss or quadratic loss.
+ */
+ def meanSquaredError: Double = {
+ val rmse = summary.normL2(1) / math.sqrt(summary.count)
+ rmse * rmse
+ }
+
+ /**
+ * Returns the root mean squared error, which is defined as the square root of
+ * the mean squared error.
+ */
+ def rootMeanSquaredError: Double = {
+ summary.normL2(1) / math.sqrt(summary.count)
+ }
+
+ /**
+ * Returns R^2^, the coefficient of determination.
+ * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
+ */
+ def r2: Double = {
+ 1 - math.pow(summary.normL2(1), 2) / (summary.variance(0) * (summary.count - 1))
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala
index 562663ad36b40..be3319d60ce25 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala
@@ -24,26 +24,43 @@ private[evaluation] trait BinaryClassificationMetricComputer extends Serializabl
def apply(c: BinaryConfusionMatrix): Double
}
-/** Precision. */
+/** Precision. Defined as 1.0 when there are no positive examples. */
private[evaluation] object Precision extends BinaryClassificationMetricComputer {
- override def apply(c: BinaryConfusionMatrix): Double =
- c.numTruePositives.toDouble / (c.numTruePositives + c.numFalsePositives)
+ override def apply(c: BinaryConfusionMatrix): Double = {
+ val totalPositives = c.numTruePositives + c.numFalsePositives
+ if (totalPositives == 0) {
+ 1.0
+ } else {
+ c.numTruePositives.toDouble / totalPositives
+ }
+ }
}
-/** False positive rate. */
+/** False positive rate. Defined as 0.0 when there are no negative examples. */
private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer {
- override def apply(c: BinaryConfusionMatrix): Double =
- c.numFalsePositives.toDouble / c.numNegatives
+ override def apply(c: BinaryConfusionMatrix): Double = {
+ if (c.numNegatives == 0) {
+ 0.0
+ } else {
+ c.numFalsePositives.toDouble / c.numNegatives
+ }
+ }
}
-/** Recall. */
+/** Recall. Defined as 0.0 when there are no positive examples. */
private[evaluation] object Recall extends BinaryClassificationMetricComputer {
- override def apply(c: BinaryConfusionMatrix): Double =
- c.numTruePositives.toDouble / c.numPositives
+ override def apply(c: BinaryConfusionMatrix): Double = {
+ if (c.numPositives == 0) {
+ 0.0
+ } else {
+ c.numTruePositives.toDouble / c.numPositives
+ }
+ }
}
/**
- * F-Measure.
+ * F-Measure. Defined as 0 if both precision and recall are 0. EG in the case that all examples
+ * are false positives.
* @param beta the beta constant in F-Measure
* @see http://en.wikipedia.org/wiki/F1_score
*/
@@ -52,6 +69,10 @@ private[evaluation] case class FMeasure(beta: Double) extends BinaryClassificati
override def apply(c: BinaryConfusionMatrix): Double = {
val precision = Precision(c)
val recall = Recall(c)
- (1.0 + beta2) * (precision * recall) / (beta2 * precision + recall)
+ if (precision + recall == 0) {
+ 0.0
+ } else {
+ (1.0 + beta2) * (precision * recall) / (beta2 * precision + recall)
+ }
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
index 3afb47767281c..a9c2e23717896 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
@@ -17,16 +17,16 @@
package org.apache.spark.mllib.feature
-import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
+import breeze.linalg.{norm => brzNorm}
import org.apache.spark.annotation.Experimental
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
/**
* :: Experimental ::
* Normalizes samples individually to unit L^p^ norm
*
- * For any 1 <= p < Double.PositiveInfinity, normalizes samples using
+ * For any 1 <= p < Double.PositiveInfinity, normalizes samples using
* sum(abs(vector).^p^)^(1/p)^ as norm.
*
* For p = Double.PositiveInfinity, max(abs(vector)) will be used as norm for normalization.
@@ -47,22 +47,31 @@ class Normalizer(p: Double) extends VectorTransformer {
* @return normalized vector. If the norm of the input is zero, it will return the input vector.
*/
override def transform(vector: Vector): Vector = {
- var norm = vector.toBreeze.norm(p)
+ val norm = brzNorm(vector.toBreeze, p)
if (norm != 0.0) {
// For dense vector, we've to allocate new memory for new output vector.
// However, for sparse vector, the `index` array will not be changed,
// so we can re-use it to save memory.
- vector.toBreeze match {
- case dv: BDV[Double] => Vectors.fromBreeze(dv :/ norm)
- case sv: BSV[Double] =>
- val output = new BSV[Double](sv.index, sv.data.clone(), sv.length)
+ vector match {
+ case dv: DenseVector =>
+ val values = dv.values.clone()
+ val size = values.size
var i = 0
- while (i < output.data.length) {
- output.data(i) /= norm
+ while (i < size) {
+ values(i) /= norm
i += 1
}
- Vectors.fromBreeze(output)
+ Vectors.dense(values)
+ case sv: SparseVector =>
+ val values = sv.values.clone()
+ val nnz = values.size
+ var i = 0
+ while (i < nnz) {
+ values(i) /= norm
+ i += 1
+ }
+ Vectors.sparse(sv.size, sv.indices, values)
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
}
} else {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
index 4dfd1f0ab8134..8c4c5db5258d5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
@@ -17,11 +17,9 @@
package org.apache.spark.mllib.feature
-import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
-
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.rdd.RDD
@@ -77,8 +75,8 @@ class StandardScalerModel private[mllib] (
require(mean.size == variance.size)
- private lazy val factor: BDV[Double] = {
- val f = BDV.zeros[Double](variance.size)
+ private lazy val factor: Array[Double] = {
+ val f = Array.ofDim[Double](variance.size)
var i = 0
while (i < f.size) {
f(i) = if (variance(i) != 0.0) 1.0 / math.sqrt(variance(i)) else 0.0
@@ -87,6 +85,11 @@ class StandardScalerModel private[mllib] (
f
}
+ // Since `shift` will be only used in `withMean` branch, we have it as
+ // `lazy val` so it will be evaluated in that branch. Note that we don't
+ // want to create this array multiple times in `transform` function.
+ private lazy val shift: Array[Double] = mean.toArray
+
/**
* Applies standardization transformation on a vector.
*
@@ -97,30 +100,57 @@ class StandardScalerModel private[mllib] (
override def transform(vector: Vector): Vector = {
require(mean.size == vector.size)
if (withMean) {
- vector.toBreeze match {
- case dv: BDV[Double] =>
- val output = vector.toBreeze.copy
- var i = 0
- while (i < output.length) {
- output(i) = (output(i) - mean(i)) * (if (withStd) factor(i) else 1.0)
- i += 1
+ // By default, Scala generates Java methods for member variables. So every time when
+ // the member variables are accessed, `invokespecial` will be called which is expensive.
+ // This can be avoid by having a local reference of `shift`.
+ val localShift = shift
+ vector match {
+ case dv: DenseVector =>
+ val values = dv.values.clone()
+ val size = values.size
+ if (withStd) {
+ // Having a local reference of `factor` to avoid overhead as the comment before.
+ val localFactor = factor
+ var i = 0
+ while (i < size) {
+ values(i) = (values(i) - localShift(i)) * localFactor(i)
+ i += 1
+ }
+ } else {
+ var i = 0
+ while (i < size) {
+ values(i) -= localShift(i)
+ i += 1
+ }
}
- Vectors.fromBreeze(output)
+ Vectors.dense(values)
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
}
} else if (withStd) {
- vector.toBreeze match {
- case dv: BDV[Double] => Vectors.fromBreeze(dv :* factor)
- case sv: BSV[Double] =>
+ // Having a local reference of `factor` to avoid overhead as the comment before.
+ val localFactor = factor
+ vector match {
+ case dv: DenseVector =>
+ val values = dv.values.clone()
+ val size = values.size
+ var i = 0
+ while(i < size) {
+ values(i) *= localFactor(i)
+ i += 1
+ }
+ Vectors.dense(values)
+ case sv: SparseVector =>
// For sparse vector, the `index` array inside sparse vector object will not be changed,
// so we can re-use it to save memory.
- val output = new BSV[Double](sv.index, sv.data.clone(), sv.length)
+ val indices = sv.indices
+ val values = sv.values.clone()
+ val nnz = values.size
var i = 0
- while (i < output.data.length) {
- output.data(i) *= factor(output.index(i))
+ while (i < nnz) {
+ values(i) *= localFactor(indices(i))
i += 1
}
- Vectors.fromBreeze(output)
+ Vectors.sparse(sv.size, indices, values)
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
}
} else {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala
index 415a845332d45..7358c1c84f79c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala
@@ -18,6 +18,7 @@
package org.apache.spark.mllib.feature
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
@@ -48,4 +49,14 @@ trait VectorTransformer extends Serializable {
data.map(x => this.transform(x))
}
+ /**
+ * Applies transformation on an JavaRDD[Vector].
+ *
+ * @param data JavaRDD[Vector] to be transformed.
+ * @return transformed JavaRDD[Vector].
+ */
+ def transform(data: JavaRDD[Vector]): JavaRDD[Vector] = {
+ transform(data.rdd)
+ }
+
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index fc1444705364a..7960f3cab576f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -67,7 +67,7 @@ private case class VocabWord(
class Word2Vec extends Serializable with Logging {
private var vectorSize = 100
- private var startingAlpha = 0.025
+ private var learningRate = 0.025
private var numPartitions = 1
private var numIterations = 1
private var seed = Utils.random.nextLong()
@@ -84,7 +84,7 @@ class Word2Vec extends Serializable with Logging {
* Sets initial learning rate (default: 0.025).
*/
def setLearningRate(learningRate: Double): this.type = {
- this.startingAlpha = learningRate
+ this.learningRate = learningRate
this
}
@@ -286,7 +286,7 @@ class Word2Vec extends Serializable with Logging {
val syn0Global =
Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
val syn1Global = new Array[Float](vocabSize * vectorSize)
- var alpha = startingAlpha
+ var alpha = learningRate
for (k <- 1 to numIterations) {
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
@@ -300,8 +300,8 @@ class Word2Vec extends Serializable with Logging {
lwc = wordCount
// TODO: discount by iteration?
alpha =
- startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
- if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
+ learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
+ if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001
logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
}
wc += sentence.size
@@ -432,18 +432,18 @@ class Word2VecModel private[mllib] (
throw new IllegalStateException(s"$word not in vocabulary")
}
}
-
+
/**
* Find synonyms of a word
* @param word a word
* @param num number of synonyms to find
- * @return array of (word, similarity)
+ * @return array of (word, cosineSimilarity)
*/
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
val vector = transform(word)
findSynonyms(vector,num)
}
-
+
/**
* Find synonyms of the vector representation of a word
* @param vector vector representation of a word
@@ -461,4 +461,11 @@ class Word2VecModel private[mllib] (
.tail
.toArray
}
+
+ /**
+ * Returns a map of words to their vector representations.
+ */
+ def getVectors: Map[String, Array[Float]] = {
+ model
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index 54ee930d61003..89539e600f48c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -25,7 +25,7 @@ import org.apache.spark.Logging
/**
* BLAS routines for MLlib's vectors and matrices.
*/
-private[mllib] object BLAS extends Serializable with Logging {
+private[spark] object BLAS extends Serializable with Logging {
@transient private var _f2jBLAS: NetlibBLAS = _
@transient private var _nativeBLAS: NetlibBLAS = _
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 2cc52e94282ba..327366a1a3a82 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -17,12 +17,10 @@
package org.apache.spark.mllib.linalg
-import java.util.Arrays
+import java.util.{Random, Arrays}
import breeze.linalg.{Matrix => BM, DenseMatrix => BDM, CSCMatrix => BSM}
-import org.apache.spark.util.random.XORShiftRandom
-
/**
* Trait for a local matrix.
*/
@@ -67,14 +65,14 @@ sealed trait Matrix extends Serializable {
}
/** Convenience method for `Matrix`^T^-`DenseMatrix` multiplication. */
- def transposeMultiply(y: DenseMatrix): DenseMatrix = {
+ private[mllib] def transposeMultiply(y: DenseMatrix): DenseMatrix = {
val C: DenseMatrix = Matrices.zeros(numCols, y.numCols).asInstanceOf[DenseMatrix]
BLAS.gemm(true, false, 1.0, this, y, 0.0, C)
C
}
/** Convenience method for `Matrix`^T^-`DenseVector` multiplication. */
- def transposeMultiply(y: DenseVector): DenseVector = {
+ private[mllib] def transposeMultiply(y: DenseVector): DenseVector = {
val output = new DenseVector(new Array[Double](numCols))
BLAS.gemv(true, 1.0, this, y, 0.0, output)
output
@@ -291,22 +289,22 @@ object Matrices {
* Generate a `DenseMatrix` consisting of i.i.d. uniform random numbers.
* @param numRows number of rows of the matrix
* @param numCols number of columns of the matrix
+ * @param rng a random number generator
* @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1)
*/
- def rand(numRows: Int, numCols: Int): Matrix = {
- val rand = new XORShiftRandom
- new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rand.nextDouble()))
+ def rand(numRows: Int, numCols: Int, rng: Random): Matrix = {
+ new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextDouble()))
}
/**
* Generate a `DenseMatrix` consisting of i.i.d. gaussian random numbers.
* @param numRows number of rows of the matrix
* @param numCols number of columns of the matrix
+ * @param rng a random number generator
* @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1)
*/
- def randn(numRows: Int, numCols: Int): Matrix = {
- val rand = new XORShiftRandom
- new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rand.nextGaussian()))
+ def randn(numRows: Int, numCols: Int, rng: Random): Matrix = {
+ new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextGaussian()))
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 6af225b7f49f7..c6d5fe5bc678c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -17,22 +17,26 @@
package org.apache.spark.mllib.linalg
-import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable}
import java.util
+import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable}
import scala.annotation.varargs
import scala.collection.JavaConverters._
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
-import org.apache.spark.mllib.util.NumericParser
import org.apache.spark.SparkException
+import org.apache.spark.mllib.util.NumericParser
+import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}
+import org.apache.spark.sql.catalyst.types._
/**
* Represents a numeric vector, whose index type is Int and value type is Double.
*
* Note: Users should not implement this interface.
*/
+@SQLUserDefinedType(udt = classOf[VectorUDT])
sealed trait Vector extends Serializable {
/**
@@ -72,6 +76,77 @@ sealed trait Vector extends Serializable {
def copy: Vector = {
throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.")
}
+
+ /**
+ * Applies a function `f` to all the active elements of dense and sparse vector.
+ *
+ * @param f the function takes two parameters where the first parameter is the index of
+ * the vector with type `Int`, and the second parameter is the corresponding value
+ * with type `Double`.
+ */
+ private[spark] def foreachActive(f: (Int, Double) => Unit)
+}
+
+/**
+ * User-defined type for [[Vector]] which allows easy interaction with SQL
+ * via [[org.apache.spark.sql.SchemaRDD]].
+ */
+private[spark] class VectorUDT extends UserDefinedType[Vector] {
+
+ override def sqlType: StructType = {
+ // type: 0 = sparse, 1 = dense
+ // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse
+ // vectors. The "values" field is nullable because we might want to add binary vectors later,
+ // which uses "size" and "indices", but not "values".
+ StructType(Seq(
+ StructField("type", ByteType, nullable = false),
+ StructField("size", IntegerType, nullable = true),
+ StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true),
+ StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
+ }
+
+ override def serialize(obj: Any): Row = {
+ val row = new GenericMutableRow(4)
+ obj match {
+ case sv: SparseVector =>
+ row.setByte(0, 0)
+ row.setInt(1, sv.size)
+ row.update(2, sv.indices.toSeq)
+ row.update(3, sv.values.toSeq)
+ case dv: DenseVector =>
+ row.setByte(0, 1)
+ row.setNullAt(1)
+ row.setNullAt(2)
+ row.update(3, dv.values.toSeq)
+ }
+ row
+ }
+
+ override def deserialize(datum: Any): Vector = {
+ datum match {
+ // TODO: something wrong with UDT serialization
+ case v: Vector =>
+ v
+ case row: Row =>
+ require(row.length == 4,
+ s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4")
+ val tpe = row.getByte(0)
+ tpe match {
+ case 0 =>
+ val size = row.getInt(1)
+ val indices = row.getAs[Iterable[Int]](2).toArray
+ val values = row.getAs[Iterable[Double]](3).toArray
+ new SparseVector(size, indices, values)
+ case 1 =>
+ val values = row.getAs[Iterable[Double]](3).toArray
+ new DenseVector(values)
+ }
+ }
+ }
+
+ override def pyUDT: String = "pyspark.mllib.linalg.VectorUDT"
+
+ override def userClass: Class[Vector] = classOf[Vector]
}
/**
@@ -171,7 +246,7 @@ object Vectors {
private[mllib] def fromBreeze(breezeVector: BV[Double]): Vector = {
breezeVector match {
case v: BDV[Double] =>
- if (v.offset == 0 && v.stride == 1) {
+ if (v.offset == 0 && v.stride == 1 && v.length == v.data.length) {
new DenseVector(v.data)
} else {
new DenseVector(v.toArray) // Can't use underlying array directly, so make a new one
@@ -191,6 +266,7 @@ object Vectors {
/**
* A dense vector represented by a value array.
*/
+@SQLUserDefinedType(udt = classOf[VectorUDT])
class DenseVector(val values: Array[Double]) extends Vector {
override def size: Int = values.length
@@ -206,6 +282,17 @@ class DenseVector(val values: Array[Double]) extends Vector {
override def copy: DenseVector = {
new DenseVector(values.clone())
}
+
+ private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
+ var i = 0
+ val localValuesSize = values.size
+ val localValues = values
+
+ while (i < localValuesSize) {
+ f(i, localValues(i))
+ i += 1
+ }
+ }
}
/**
@@ -215,6 +302,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
* @param indices index array, assume to be strictly increasing.
* @param values value array, must have the same length as the index array.
*/
+@SQLUserDefinedType(udt = classOf[VectorUDT])
class SparseVector(
override val size: Int,
val indices: Array[Int],
@@ -241,4 +329,16 @@ class SparseVector(
}
private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size)
+
+ private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
+ var i = 0
+ val localValuesSize = values.size
+ val localIndices = indices
+ val localValues = values
+
+ while (i < localValuesSize) {
+ f(localIndices(i), localValues(i))
+ i += 1
+ }
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index 8380058cf9b41..10a515af88802 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -111,7 +111,10 @@ class RowMatrix(
*/
def computeGramianMatrix(): Matrix = {
val n = numCols().toInt
- val nt: Int = n * (n + 1) / 2
+ checkNumColumns(n)
+ // Computes n*(n+1)/2, avoiding overflow in the multiplication.
+ // This succeeds when n <= 65535, which is checked above
+ val nt: Int = if (n % 2 == 0) ((n / 2) * (n + 1)) else (n * ((n + 1) / 2))
// Compute the upper triangular part of the gram matrix.
val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))(
@@ -123,6 +126,16 @@ class RowMatrix(
RowMatrix.triuToFull(n, GU.data)
}
+ private def checkNumColumns(cols: Int): Unit = {
+ if (cols > 65535) {
+ throw new IllegalArgumentException(s"Argument with more than 65535 cols: $cols")
+ }
+ if (cols > 10000) {
+ val mem = cols * cols * 8
+ logWarning(s"$cols columns will require at least $mem bytes of memory!")
+ }
+ }
+
/**
* Computes singular value decomposition of this matrix. Denote this matrix by A (m x n). This
* will compute matrices U, S, V such that A ~= U * S * V', where S contains the leading k
@@ -139,7 +152,7 @@ class RowMatrix(
* storing the right singular vectors, is computed via matrix multiplication as
* U = A * (V * S^-1^), if requested by user. The actual method to use is determined
* automatically based on the cost:
- * - If n is small (n < 100) or k is large compared with n (k > n / 2), we compute the Gramian
+ * - If n is small (n < 100) or k is large compared with n (k > n / 2), we compute the Gramian
* matrix first and then compute its top eigenvalues and eigenvectors locally on the driver.
* This requires a single pass with O(n^2^) storage on each executor and on the driver, and
* O(n^2^ k) time on the driver.
@@ -156,7 +169,8 @@ class RowMatrix(
* @note The conditions that decide which method to use internally and the default parameters are
* subject to change.
*
- * @param k number of leading singular values to keep (0 < k <= n). It might return less than k if
+ * @param k number of leading singular values to keep (0 < k <= n).
+ * It might return less than k if
* there are numerically zero singular values or there are not enough Ritz values
* converged before the maximum number of Arnoldi update iterations is reached (in case
* that matrix A is ill-conditioned).
@@ -179,7 +193,7 @@ class RowMatrix(
/**
* The actual SVD implementation, visible for testing.
*
- * @param k number of leading singular values to keep (0 < k <= n)
+ * @param k number of leading singular values to keep (0 < k <= n)
* @param computeU whether to compute U
* @param rCond the reciprocal condition number
* @param maxIter max number of iterations (if ARPACK is used)
@@ -301,12 +315,7 @@ class RowMatrix(
*/
def computeCovariance(): Matrix = {
val n = numCols().toInt
-
- if (n > 10000) {
- val mem = n * n * java.lang.Double.SIZE / java.lang.Byte.SIZE
- logWarning(s"The number of columns $n is greater than 10000! " +
- s"We need at least $mem bytes of memory.")
- }
+ checkNumColumns(n)
val (m, mean) = rows.treeAggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))(
seqOp = (s: (Long, BDV[Double]), v: Vector) => (s._1 + 1L, s._2 += v.toBreeze),
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index a6912056395d7..0857877951c82 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -160,14 +160,15 @@ object GradientDescent extends Logging {
val stochasticLossHistory = new ArrayBuffer[Double](numIterations)
val numExamples = data.count()
- val miniBatchSize = numExamples * miniBatchFraction
// if no data, return initial weights to avoid NaNs
if (numExamples == 0) {
-
- logInfo("GradientDescent.runMiniBatchSGD returning initial weights, no data found")
+ logWarning("GradientDescent.runMiniBatchSGD returning initial weights, no data found")
return (initialWeights, stochasticLossHistory.toArray)
+ }
+ if (numExamples * miniBatchFraction < 1) {
+ logWarning("The miniBatchFraction is too small")
}
// Initialize weights as a column vector
@@ -185,25 +186,31 @@ object GradientDescent extends Logging {
val bcWeights = data.context.broadcast(weights)
// Sample a subset (fraction miniBatchFraction) of the total data
// compute and sum up the subgradients on this subset (this is one map-reduce)
- val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i)
- .treeAggregate((BDV.zeros[Double](n), 0.0))(
- seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
- val l = gradient.compute(features, label, bcWeights.value, Vectors.fromBreeze(grad))
- (grad, loss + l)
+ val (gradientSum, lossSum, miniBatchSize) = data.sample(false, miniBatchFraction, 42 + i)
+ .treeAggregate((BDV.zeros[Double](n), 0.0, 0L))(
+ seqOp = (c, v) => {
+ // c: (grad, loss, count), v: (label, features)
+ val l = gradient.compute(v._2, v._1, bcWeights.value, Vectors.fromBreeze(c._1))
+ (c._1, c._2 + l, c._3 + 1)
},
- combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
- (grad1 += grad2, loss1 + loss2)
+ combOp = (c1, c2) => {
+ // c: (grad, loss, count)
+ (c1._1 += c2._1, c1._2 + c2._2, c1._3 + c2._3)
})
- /**
- * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration
- * and regVal is the regularization value computed in the previous iteration as well.
- */
- stochasticLossHistory.append(lossSum / miniBatchSize + regVal)
- val update = updater.compute(
- weights, Vectors.fromBreeze(gradientSum / miniBatchSize), stepSize, i, regParam)
- weights = update._1
- regVal = update._2
+ if (miniBatchSize > 0) {
+ /**
+ * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration
+ * and regVal is the regularization value computed in the previous iteration as well.
+ */
+ stochasticLossHistory.append(lossSum / miniBatchSize + regVal)
+ val update = updater.compute(
+ weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble), stepSize, i, regParam)
+ weights = update._1
+ regVal = update._2
+ } else {
+ logWarning(s"Iteration ($i/$numIterations). The size of sampled batch is zero")
+ }
}
logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format(
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala
index e4b436b023794..fef062e02b6ec 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala
@@ -79,7 +79,7 @@ private[mllib] object NNLS {
// stopping condition
def stop(step: Double, ndir: Double, nx: Double): Boolean = {
((step.isNaN) // NaN
- || (step < 1e-6) // too small or negative
+ || (step < 1e-7) // too small or negative
|| (step > 1e40) // too small; almost certainly numerical problems
|| (ndir < 1e-12 * nx) // gradient relatively too small
|| (ndir < 1e-32) // gradient absolutely too small; numerical issues may lurk
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
index 28179fbc450c0..51f9b8657c640 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
@@ -17,8 +17,7 @@
package org.apache.spark.mllib.random
-import cern.jet.random.Poisson
-import cern.jet.random.engine.DRand
+import org.apache.commons.math3.distribution.PoissonDistribution
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom}
@@ -89,12 +88,13 @@ class StandardNormalGenerator extends RandomDataGenerator[Double] {
@DeveloperApi
class PoissonGenerator(val mean: Double) extends RandomDataGenerator[Double] {
- private var rng = new Poisson(mean, new DRand)
+ private var rng = new PoissonDistribution(mean)
- override def nextValue(): Double = rng.nextDouble()
+ override def nextValue(): Double = rng.sample()
override def setSeed(seed: Long) {
- rng = new Poisson(mean, new DRand(seed.toInt))
+ rng = new PoissonDistribution(mean)
+ rng.reseedRandomGenerator(seed)
}
override def copy(): PoissonGenerator = new PoissonGenerator(mean)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
index b5e403bc8c14d..57c0768084e41 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.rdd
import scala.language.implicitConversions
import scala.reflect.ClassTag
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.HashPartitioner
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
@@ -28,8 +29,8 @@ import org.apache.spark.util.Utils
/**
* Machine learning specific RDD functions.
*/
-private[mllib]
-class RDDFunctions[T: ClassTag](self: RDD[T]) {
+@DeveloperApi
+class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable {
/**
* Returns a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding
@@ -39,10 +40,10 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
* trigger a Spark job if the parent RDD has more than one partitions and the window size is
* greater than 1.
*/
- def sliding(windowSize: Int): RDD[Seq[T]] = {
+ def sliding(windowSize: Int): RDD[Array[T]] = {
require(windowSize > 0, s"Sliding window size must be positive, but got $windowSize.")
if (windowSize == 1) {
- self.map(Seq(_))
+ self.map(Array(_))
} else {
new SlidingRDD[T](self, windowSize)
}
@@ -112,7 +113,7 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) {
}
}
-private[mllib]
+@DeveloperApi
object RDDFunctions {
/** Implicit conversion from an RDD to RDDFunctions. */
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
index dd80782c0f001..35e81fcb3de0d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala
@@ -45,15 +45,16 @@ class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T]
*/
private[mllib]
class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int)
- extends RDD[Seq[T]](parent) {
+ extends RDD[Array[T]](parent) {
require(windowSize > 1, s"Window size must be greater than 1, but got $windowSize.")
- override def compute(split: Partition, context: TaskContext): Iterator[Seq[T]] = {
+ override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = {
val part = split.asInstanceOf[SlidingRDDPartition[T]]
(firstParent[T].iterator(part.prev, context) ++ part.tail)
.sliding(windowSize)
.withPartial(false)
+ .map(_.toArray)
}
override def getPreferredLocations(split: Partition): Seq[String] =
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index 84d192db53e26..90ac252226006 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -20,20 +20,20 @@ package org.apache.spark.mllib.recommendation
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.math.{abs, sqrt}
-import scala.util.Random
-import scala.util.Sorting
+import scala.util.{Random, Sorting}
import scala.util.hashing.byteswap32
import org.jblas.{DoubleMatrix, SimpleBlas, Solve}
+import org.apache.spark.{HashPartitioner, Logging, Partitioner}
+import org.apache.spark.SparkContext._
import org.apache.spark.annotation.{DeveloperApi, Experimental}
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.{Logging, HashPartitioner, Partitioner}
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext._
+import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
-import org.apache.spark.mllib.optimization.NNLS
/**
* Out-link information for a user or product block. This includes the original user/product IDs
@@ -325,6 +325,11 @@ class ALS private (
new MatrixFactorizationModel(rank, usersOut, productsOut)
}
+ /**
+ * Java-friendly version of [[ALS.run]].
+ */
+ def run(ratings: JavaRDD[Rating]): MatrixFactorizationModel = run(ratings.rdd)
+
/**
* Computes the (`rank x rank`) matrix `YtY`, where `Y` is the (`nui x rank`) matrix of factors
* for each user (or product), in a distributed fashion.
@@ -741,7 +746,7 @@ object ALS {
* @param iterations number of iterations of ALS (recommended: 10-20)
* @param lambda regularization factor (recommended: 0.01)
* @param blocks level of parallelism to split computation into
- * @param alpha confidence parameter (only applies when immplicitPrefs = true)
+ * @param alpha confidence parameter
* @param seed random seed
*/
def trainImplicit(
@@ -768,7 +773,7 @@ object ALS {
* @param iterations number of iterations of ALS (recommended: 10-20)
* @param lambda regularization factor (recommended: 0.01)
* @param blocks level of parallelism to split computation into
- * @param alpha confidence parameter (only applies when immplicitPrefs = true)
+ * @param alpha confidence parameter
*/
def trainImplicit(
ratings: RDD[Rating],
@@ -792,6 +797,7 @@ object ALS {
* @param rank number of features to use
* @param iterations number of iterations of ALS (recommended: 10-20)
* @param lambda regularization factor (recommended: 0.01)
+ * @param alpha confidence parameter
*/
def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double, alpha: Double)
: MatrixFactorizationModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 66b58ba770160..ed2f8b41bcae5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -17,27 +17,49 @@
package org.apache.spark.mllib.recommendation
+import java.lang.{Integer => JavaInteger}
+
import org.jblas.DoubleMatrix
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.Logging
+import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.api.python.SerDe
+import org.apache.spark.storage.StorageLevel
/**
* Model representing the result of matrix factorization.
*
+ * Note: If you create the model directly using constructor, please be aware that fast prediction
+ * requires cached user/product features and their associated partitioners.
+ *
* @param rank Rank for the features in this model.
* @param userFeatures RDD of tuples where each tuple represents the userId and
* the features computed for this user.
* @param productFeatures RDD of tuples where each tuple represents the productId
* and the features computed for this product.
*/
-class MatrixFactorizationModel private[mllib] (
+class MatrixFactorizationModel(
val rank: Int,
val userFeatures: RDD[(Int, Array[Double])],
- val productFeatures: RDD[(Int, Array[Double])]) extends Serializable {
+ val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging {
+
+ require(rank > 0)
+ validateFeatures("User", userFeatures)
+ validateFeatures("Product", productFeatures)
+
+ /** Validates factors and warns users if there are performance concerns. */
+ private def validateFeatures(name: String, features: RDD[(Int, Array[Double])]): Unit = {
+ require(features.first()._2.size == rank,
+ s"$name feature dimension does not match the rank $rank.")
+ if (features.partitioner.isEmpty) {
+ logWarning(s"$name factor does not have a partitioner. "
+ + "Prediction on individual records could be slow.")
+ }
+ if (features.getStorageLevel == StorageLevel.NONE) {
+ logWarning(s"$name factor is not cached. Prediction could be slow.")
+ }
+ }
+
/** Predict the rating of one user for one product. */
def predict(user: Int, product: Int): Double = {
val userVector = new DoubleMatrix(userFeatures.lookup(user).head)
@@ -65,6 +87,13 @@ class MatrixFactorizationModel private[mllib] (
}
}
+ /**
+ * Java-friendly version of [[MatrixFactorizationModel.predict]].
+ */
+ def predict(usersProducts: JavaPairRDD[JavaInteger, JavaInteger]): JavaRDD[Rating] = {
+ predict(usersProducts.rdd.asInstanceOf[RDD[(Int, Int)]]).toJavaRDD()
+ }
+
/**
* Recommends products to a user.
*
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index d0fe4179685ca..0287f04e2c777 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -75,6 +75,8 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double
def predict(testData: Vector): Double = {
predictPoint(testData, weights, intercept)
}
+
+ override def toString() = "(weights=%s, intercept=%s)".format(weights, intercept)
}
/**
@@ -134,15 +136,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
this
}
- /** Whether a warning should be logged if the input RDD is uncached. */
- private var warnOnUncachedInput = true
-
- /** Disable warnings about uncached input. */
- private[spark] def disableUncachedWarning(): this.type = {
- warnOnUncachedInput = false
- this
- }
-
/**
* Run the algorithm with the configured parameters on an input
* RDD of LabeledPoint entries.
@@ -159,7 +152,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
*/
def run(input: RDD[LabeledPoint], initialWeights: Vector): M = {
- if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) {
+ if (input.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data is not directly cached, which may hurt performance if its"
+ " parent RDDs are also uncached.")
}
@@ -239,7 +232,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
}
// Warn at the end of the run as well, for increased visibility.
- if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) {
+ if (input.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data was not directly cached, which may hurt performance if its"
+ " parent RDDs are also uncached.")
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
index 17c753c56681f..2067b36f246b3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala
@@ -17,6 +17,8 @@
package org.apache.spark.mllib.regression
+import scala.beans.BeanInfo
+
import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.util.NumericParser
import org.apache.spark.SparkException
@@ -27,6 +29,7 @@ import org.apache.spark.SparkException
* @param label Label for this data point.
* @param features List of features for this data point.
*/
+@BeanInfo
case class LabeledPoint(label: Double, features: Vector) {
override def toString: String = {
"(%s,%s)".format(label, features)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index cb0d39e759a9f..f9791c6571782 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -67,9 +67,9 @@ class LassoWithSGD private (
/**
* Construct a Lasso object with default parameters: {stepSize: 1.0, numIterations: 100,
- * regParam: 1.0, miniBatchFraction: 1.0}.
+ * regParam: 0.01, miniBatchFraction: 1.0}.
*/
- def this() = this(1.0, 100, 1.0, 1.0)
+ def this() = this(1.0, 100, 0.01, 1.0)
override protected def createModel(weights: Vector, intercept: Double) = {
new LassoModel(weights, intercept)
@@ -161,6 +161,6 @@ object LassoWithSGD {
def train(
input: RDD[LabeledPoint],
numIterations: Int): LassoModel = {
- train(input, numIterations, 1.0, 1.0, 1.0)
+ train(input, numIterations, 1.0, 0.01, 1.0)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index a826deb695ee1..c8cad773f5efb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -68,9 +68,9 @@ class RidgeRegressionWithSGD private (
/**
* Construct a RidgeRegression object with default parameters: {stepSize: 1.0, numIterations: 100,
- * regParam: 1.0, miniBatchFraction: 1.0}.
+ * regParam: 0.01, miniBatchFraction: 1.0}.
*/
- def this() = this(1.0, 100, 1.0, 1.0)
+ def this() = this(1.0, 100, 0.01, 1.0)
override protected def createModel(weights: Vector, intercept: Double) = {
new RidgeRegressionModel(weights, intercept)
@@ -143,7 +143,7 @@ object RidgeRegressionWithSGD {
numIterations: Int,
stepSize: Double,
regParam: Double): RidgeRegressionModel = {
- train(input, numIterations, stepSize, regParam, 1.0)
+ train(input, numIterations, stepSize, regParam, 0.01)
}
/**
@@ -158,6 +158,6 @@ object RidgeRegressionWithSGD {
def train(
input: RDD[LabeledPoint],
numIterations: Int): RidgeRegressionModel = {
- train(input, numIterations, 1.0, 1.0, 1.0)
+ train(input, numIterations, 1.0, 0.01, 1.0)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index 3025d4837cab4..fcc2a148791bd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -17,8 +17,6 @@
package org.apache.spark.mllib.stat
-import breeze.linalg.{DenseVector => BDV}
-
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.{Vectors, Vector}
@@ -40,14 +38,14 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector}
class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable {
private var n = 0
- private var currMean: BDV[Double] = _
- private var currM2n: BDV[Double] = _
- private var currM2: BDV[Double] = _
- private var currL1: BDV[Double] = _
+ private var currMean: Array[Double] = _
+ private var currM2n: Array[Double] = _
+ private var currM2: Array[Double] = _
+ private var currL1: Array[Double] = _
private var totalCnt: Long = 0
- private var nnz: BDV[Double] = _
- private var currMax: BDV[Double] = _
- private var currMin: BDV[Double] = _
+ private var nnz: Array[Double] = _
+ private var currMax: Array[Double] = _
+ private var currMin: Array[Double] = _
/**
* Add a new sample to this summarizer, and update the statistical summary.
@@ -60,35 +58,36 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
require(sample.size > 0, s"Vector should have dimension larger than zero.")
n = sample.size
- currMean = BDV.zeros[Double](n)
- currM2n = BDV.zeros[Double](n)
- currM2 = BDV.zeros[Double](n)
- currL1 = BDV.zeros[Double](n)
- nnz = BDV.zeros[Double](n)
- currMax = BDV.fill(n)(Double.MinValue)
- currMin = BDV.fill(n)(Double.MaxValue)
+ currMean = Array.ofDim[Double](n)
+ currM2n = Array.ofDim[Double](n)
+ currM2 = Array.ofDim[Double](n)
+ currL1 = Array.ofDim[Double](n)
+ nnz = Array.ofDim[Double](n)
+ currMax = Array.fill[Double](n)(Double.MinValue)
+ currMin = Array.fill[Double](n)(Double.MaxValue)
}
require(n == sample.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $n but got ${sample.size}.")
- sample.toBreeze.activeIterator.foreach {
- case (_, 0.0) => // Skip explicit zero elements.
- case (i, value) =>
- if (currMax(i) < value) {
- currMax(i) = value
+ sample.foreachActive { (index, value) =>
+ if (value != 0.0) {
+ if (currMax(index) < value) {
+ currMax(index) = value
}
- if (currMin(i) > value) {
- currMin(i) = value
+ if (currMin(index) > value) {
+ currMin(index) = value
}
- val tmpPrevMean = currMean(i)
- currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
- currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)
- currM2(i) += value * value
- currL1(i) += math.abs(value)
+ val prevMean = currMean(index)
+ val diff = value - prevMean
+ currMean(index) = prevMean + diff / (nnz(index) + 1.0)
+ currM2n(index) += (value - currMean(index)) * diff
+ currM2(index) += value * value
+ currL1(index) += math.abs(value)
- nnz(i) += 1.0
+ nnz(index) += 1.0
+ }
}
totalCnt += 1
@@ -107,47 +106,38 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " +
s"Expecting $n but got ${other.n}.")
totalCnt += other.totalCnt
- val deltaMean: BDV[Double] = currMean - other.currMean
var i = 0
while (i < n) {
- // merge mean together
- if (other.currMean(i) != 0.0) {
- currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
- (nnz(i) + other.nnz(i))
- }
- // merge m2n together
- if (nnz(i) + other.nnz(i) != 0.0) {
- currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
- (nnz(i) + other.nnz(i))
- }
- // merge m2 together
- if (nnz(i) + other.nnz(i) != 0.0) {
+ val thisNnz = nnz(i)
+ val otherNnz = other.nnz(i)
+ val totalNnz = thisNnz + otherNnz
+ if (totalNnz != 0.0) {
+ val deltaMean = other.currMean(i) - currMean(i)
+ // merge mean together
+ currMean(i) += deltaMean * otherNnz / totalNnz
+ // merge m2n together
+ currM2n(i) += other.currM2n(i) + deltaMean * deltaMean * thisNnz * otherNnz / totalNnz
+ // merge m2 together
currM2(i) += other.currM2(i)
- }
- // merge l1 together
- if (nnz(i) + other.nnz(i) != 0.0) {
+ // merge l1 together
currL1(i) += other.currL1(i)
+ // merge max and min
+ currMax(i) = math.max(currMax(i), other.currMax(i))
+ currMin(i) = math.min(currMin(i), other.currMin(i))
}
-
- if (currMax(i) < other.currMax(i)) {
- currMax(i) = other.currMax(i)
- }
- if (currMin(i) > other.currMin(i)) {
- currMin(i) = other.currMin(i)
- }
+ nnz(i) = totalNnz
i += 1
}
- nnz += other.nnz
} else if (totalCnt == 0 && other.totalCnt != 0) {
this.n = other.n
- this.currMean = other.currMean.copy
- this.currM2n = other.currM2n.copy
- this.currM2 = other.currM2.copy
- this.currL1 = other.currL1.copy
+ this.currMean = other.currMean.clone
+ this.currM2n = other.currM2n.clone
+ this.currM2 = other.currM2.clone
+ this.currL1 = other.currL1.clone
this.totalCnt = other.totalCnt
- this.nnz = other.nnz.copy
- this.currMax = other.currMax.copy
- this.currMin = other.currMin.copy
+ this.nnz = other.nnz.clone
+ this.currMax = other.currMax.clone
+ this.currMin = other.currMin.clone
}
this
}
@@ -155,19 +145,19 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
override def mean: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
- val realMean = BDV.zeros[Double](n)
+ val realMean = Array.ofDim[Double](n)
var i = 0
while (i < n) {
realMean(i) = currMean(i) * (nnz(i) / totalCnt)
i += 1
}
- Vectors.fromBreeze(realMean)
+ Vectors.dense(realMean)
}
override def variance: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
- val realVariance = BDV.zeros[Double](n)
+ val realVariance = Array.ofDim[Double](n)
val denominator = totalCnt - 1.0
@@ -182,8 +172,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
i += 1
}
}
-
- Vectors.fromBreeze(realVariance)
+ Vectors.dense(realVariance)
}
override def count: Long = totalCnt
@@ -191,7 +180,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
override def numNonzeros: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
- Vectors.fromBreeze(nnz)
+ Vectors.dense(nnz)
}
override def max: Vector = {
@@ -202,7 +191,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
i += 1
}
- Vectors.fromBreeze(currMax)
+ Vectors.dense(currMax)
}
override def min: Vector = {
@@ -213,25 +202,25 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
i += 1
}
- Vectors.fromBreeze(currMin)
+ Vectors.dense(currMin)
}
override def normL2: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
- val realMagnitude = BDV.zeros[Double](n)
+ val realMagnitude = Array.ofDim[Double](n)
var i = 0
while (i < currM2.size) {
realMagnitude(i) = math.sqrt(currM2(i))
i += 1
}
-
- Vectors.fromBreeze(realMagnitude)
+ Vectors.dense(realMagnitude)
}
override def normL1: Vector = {
require(totalCnt > 0, s"Nothing has been added to this summarizer.")
- Vectors.fromBreeze(currL1)
+
+ Vectors.dense(currL1)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
index 0089419c2c5d4..ea82d39b72c03 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
@@ -18,7 +18,7 @@
package org.apache.spark.mllib.stat.test
import breeze.linalg.{DenseMatrix => BDM}
-import cern.jet.stat.Probability.chiSquareComplemented
+import org.apache.commons.math3.distribution.ChiSquaredDistribution
import org.apache.spark.{SparkException, Logging}
import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
@@ -33,7 +33,7 @@ import scala.collection.mutable
* on an input of type `Matrix` in which independence between columns is assessed.
* We also provide a method for computing the chi-squared statistic between each feature and the
* label for an input `RDD[LabeledPoint]`, return an `Array[ChiSquaredTestResult]` of size =
- * number of features in the inpuy RDD.
+ * number of features in the input RDD.
*
* Supported methods for goodness of fit: `pearson` (default)
* Supported methods for independence: `pearson` (default)
@@ -139,7 +139,7 @@ private[stat] object ChiSqTest extends Logging {
}
/*
- * Pearon's goodness of fit test on the input observed and expected counts/relative frequencies.
+ * Pearson's goodness of fit test on the input observed and expected counts/relative frequencies.
* Uniform distribution is assumed when `expected` is not passed in.
*/
def chiSquared(observed: Vector,
@@ -188,12 +188,12 @@ private[stat] object ChiSqTest extends Logging {
}
}
val df = size - 1
- val pValue = chiSquareComplemented(df, statistic)
+ val pValue = 1.0 - new ChiSquaredDistribution(df).cumulativeProbability(statistic)
new ChiSqTestResult(pValue, df, statistic, PEARSON.name, NullHypothesis.goodnessOfFit.toString)
}
/*
- * Pearon's independence test on the input contingency matrix.
+ * Pearson's independence test on the input contingency matrix.
* TODO: optimize for SparseMatrix when it becomes supported.
*/
def chiSquaredMatrix(counts: Matrix, methodName:String = PEARSON.name): ChiSqTestResult = {
@@ -238,7 +238,13 @@ private[stat] object ChiSqTest extends Logging {
j += 1
}
val df = (numCols - 1) * (numRows - 1)
- val pValue = chiSquareComplemented(df, statistic)
- new ChiSqTestResult(pValue, df, statistic, methodName, NullHypothesis.independence.toString)
+ if (df == 0) {
+ // 1 column or 1 row. Constant distribution is independent of anything.
+ // pValue = 1.0 and statistic = 0.0 in this case.
+ new ChiSqTestResult(1.0, 0, 0.0, methodName, NullHypothesis.independence.toString)
+ } else {
+ val pValue = 1.0 - new ChiSquaredDistribution(df).cumulativeProbability(statistic)
+ new ChiSqTestResult(pValue, df, statistic, methodName, NullHypothesis.independence.toString)
+ }
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index b311d10023894..3d91867c896d9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -19,6 +19,8 @@ package org.apache.spark.mllib.tree
import scala.collection.JavaConverters._
import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
@@ -56,13 +58,19 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
* @return DecisionTreeModel that can be used for prediction
*/
- def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
+ def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
// Note: random seed will not be used since numTrees = 1.
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
- val rfModel = rf.train(input)
+ val rfModel = rf.run(input)
rfModel.trees(0)
}
+ /**
+ * Trains a decision tree model over an RDD. This is deprecated because it hides the static
+ * methods with the same name in Java.
+ */
+ @deprecated("Please use DecisionTree.run instead.", "1.2.0")
+ def train(input: RDD[LabeledPoint]): DecisionTreeModel = run(input)
}
object DecisionTree extends Serializable with Logging {
@@ -84,7 +92,7 @@ object DecisionTree extends Serializable with Logging {
* @return DecisionTreeModel that can be used for prediction
*/
def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
- new DecisionTree(strategy).train(input)
+ new DecisionTree(strategy).run(input)
}
/**
@@ -110,7 +118,7 @@ object DecisionTree extends Serializable with Logging {
impurity: Impurity,
maxDepth: Int): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth)
- new DecisionTree(strategy).train(input)
+ new DecisionTree(strategy).run(input)
}
/**
@@ -138,7 +146,7 @@ object DecisionTree extends Serializable with Logging {
maxDepth: Int,
numClassesForClassification: Int): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification)
- new DecisionTree(strategy).train(input)
+ new DecisionTree(strategy).run(input)
}
/**
@@ -175,7 +183,7 @@ object DecisionTree extends Serializable with Logging {
categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
quantileCalculationStrategy, categoricalFeaturesInfo)
- new DecisionTree(strategy).train(input)
+ new DecisionTree(strategy).run(input)
}
/**
@@ -435,6 +443,11 @@ object DecisionTree extends Serializable with Logging {
* @param bins possible bins for all features, indexed (numFeatures)(numBins)
* @param nodeQueue Queue of nodes to split, with values (treeIndex, node).
* Updated with new non-leaf nodes which are created.
+ * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
+ * each value in the array is the data point's node Id
+ * for a corresponding tree. This is used to prevent the need
+ * to pass the entire tree to the executors during
+ * the node stat aggregation phase.
*/
private[tree] def findBestSplits(
input: RDD[BaggedPoint[TreePoint]],
@@ -445,7 +458,8 @@ object DecisionTree extends Serializable with Logging {
splits: Array[Array[Split]],
bins: Array[Array[Bin]],
nodeQueue: mutable.Queue[(Int, Node)],
- timer: TimeTracker = new TimeTracker): Unit = {
+ timer: TimeTracker = new TimeTracker,
+ nodeIdCache: Option[NodeIdCache] = None): Unit = {
/*
* The high-level descriptions of the best split optimizations are noted here.
@@ -477,6 +491,37 @@ object DecisionTree extends Serializable with Logging {
logDebug("isMulticlass = " + metadata.isMulticlass)
logDebug("isMulticlassWithCategoricalFeatures = " +
metadata.isMulticlassWithCategoricalFeatures)
+ logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString)
+
+ /**
+ * Performs a sequential aggregation over a partition for a particular tree and node.
+ *
+ * For each feature, the aggregate sufficient statistics are updated for the relevant
+ * bins.
+ *
+ * @param treeIndex Index of the tree that we want to perform aggregation for.
+ * @param nodeInfo The node info for the tree node.
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics
+ * for each (node, feature, bin).
+ * @param baggedPoint Data point being aggregated.
+ */
+ def nodeBinSeqOp(
+ treeIndex: Int,
+ nodeInfo: RandomForest.NodeIndexInfo,
+ agg: Array[DTStatsAggregator],
+ baggedPoint: BaggedPoint[TreePoint]): Unit = {
+ if (nodeInfo != null) {
+ val aggNodeIndex = nodeInfo.nodeIndexInGroup
+ val featuresForNode = nodeInfo.featureSubset
+ val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
+ if (metadata.unorderedFeatures.isEmpty) {
+ orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
+ } else {
+ mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
+ instanceWeight, featuresForNode)
+ }
+ }
+ }
/**
* Performs a sequential aggregation over a partition.
@@ -495,20 +540,25 @@ object DecisionTree extends Serializable with Logging {
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
bins, metadata.unorderedFeatures)
- val nodeInfo = nodeIndexToInfo.getOrElse(nodeIndex, null)
- // If the example does not reach a node in this group, then nodeIndex = null.
- if (nodeInfo != null) {
- val aggNodeIndex = nodeInfo.nodeIndexInGroup
- val featuresForNode = nodeInfo.featureSubset
- val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
- if (metadata.unorderedFeatures.isEmpty) {
- orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
- } else {
- mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
- instanceWeight, featuresForNode)
- }
- }
+ nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
+ }
+
+ agg
+ }
+
+ /**
+ * Do the same thing as binSeqOp, but with nodeIdCache.
+ */
+ def binSeqOpWithNodeIdCache(
+ agg: Array[DTStatsAggregator],
+ dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = {
+ treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
+ val baggedPoint = dataPoint._1
+ val nodeIdCache = dataPoint._2
+ val nodeIndex = nodeIdCache(treeIndex)
+ nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
}
+
agg
}
@@ -532,6 +582,14 @@ object DecisionTree extends Serializable with Logging {
Some(mutableNodeToFeatures.toMap)
}
+ // array of nodes to train indexed by node index in group
+ val nodes = new Array[Node](numNodes)
+ nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
+ nodesForTree.foreach { node =>
+ nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
+ }
+ }
+
// Calculate best splits for all nodes in the group
timer.start("chooseSplits")
@@ -543,7 +601,26 @@ object DecisionTree extends Serializable with Logging {
// Finally, only best Splits for nodes are collected to driver to construct decision tree.
val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
- val nodeToBestSplits =
+
+ val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
+ input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
+ // Construct a nodeStatsAggregators array to hold node aggregate stats,
+ // each node will have a nodeStatsAggregator
+ val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
+ val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+ Some(nodeToFeatures(nodeIndex))
+ }
+ new DTStatsAggregator(metadata, featuresForNode)
+ }
+
+ // iterator all instances in current partition and update aggregate stats
+ points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))
+
+ // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
+ // which can be combined with other partition using `reduceByKey`
+ nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
+ }
+ } else {
input.mapPartitions { points =>
// Construct a nodeStatsAggregators array to hold node aggregate stats,
// each node will have a nodeStatsAggregator
@@ -560,7 +637,10 @@ object DecisionTree extends Serializable with Logging {
// transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
// which can be combined with other partition using `reduceByKey`
nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
- }.reduceByKey((a, b) => a.merge(b))
+ }
+ }
+
+ val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b))
.map { case (nodeIndex, aggStats) =>
val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
Some(nodeToFeatures(nodeIndex))
@@ -568,12 +648,19 @@ object DecisionTree extends Serializable with Logging {
// find best split for each node
val (split: Split, stats: InformationGainStats, predict: Predict) =
- binsToBestSplit(aggStats, splits, featuresForNode)
+ binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
(nodeIndex, (split, stats, predict))
}.collectAsMap()
timer.stop("chooseSplits")
+ val nodeIdUpdaters = if (nodeIdCache.nonEmpty) {
+ Array.fill[mutable.Map[Int, NodeIndexUpdater]](
+ metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]())
+ } else {
+ null
+ }
+
// Iterate over all nodes in this group.
nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
nodesForTree.foreach { node =>
@@ -587,17 +674,37 @@ object DecisionTree extends Serializable with Logging {
// Extract info for this node. Create children if not leaf.
val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth)
assert(node.id == nodeIndex)
- node.predict = predict.predict
+ node.predict = predict
node.isLeaf = isLeaf
node.stats = Some(stats)
+ node.impurity = stats.impurity
logDebug("Node = " + node)
if (!isLeaf) {
node.split = Some(split)
- node.leftNode = Some(Node.emptyNode(Node.leftChildIndex(nodeIndex)))
- node.rightNode = Some(Node.emptyNode(Node.rightChildIndex(nodeIndex)))
- nodeQueue.enqueue((treeIndex, node.leftNode.get))
- nodeQueue.enqueue((treeIndex, node.rightNode.get))
+ val childIsLeaf = (Node.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
+ val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
+ val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
+ node.leftNode = Some(Node(Node.leftChildIndex(nodeIndex),
+ stats.leftPredict, stats.leftImpurity, leftChildIsLeaf))
+ node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex),
+ stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
+
+ if (nodeIdCache.nonEmpty) {
+ val nodeIndexUpdater = NodeIndexUpdater(
+ split = split,
+ nodeIndex = nodeIndex)
+ nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater)
+ }
+
+ // enqueue left child and right child if they are not leaves
+ if (!leftChildIsLeaf) {
+ nodeQueue.enqueue((treeIndex, node.leftNode.get))
+ }
+ if (!rightChildIsLeaf) {
+ nodeQueue.enqueue((treeIndex, node.rightNode.get))
+ }
+
logDebug("leftChildIndex = " + node.leftNode.get.id +
", impurity = " + stats.leftImpurity)
logDebug("rightChildIndex = " + node.rightNode.get.id +
@@ -606,6 +713,10 @@ object DecisionTree extends Serializable with Logging {
}
}
+ if (nodeIdCache.nonEmpty) {
+ // Update the cache if needed.
+ nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, bins)
+ }
}
/**
@@ -617,7 +728,8 @@ object DecisionTree extends Serializable with Logging {
private def calculateGainForSplit(
leftImpurityCalculator: ImpurityCalculator,
rightImpurityCalculator: ImpurityCalculator,
- metadata: DecisionTreeMetadata): InformationGainStats = {
+ metadata: DecisionTreeMetadata,
+ impurity: Double): InformationGainStats = {
val leftCount = leftImpurityCalculator.count
val rightCount = rightImpurityCalculator.count
@@ -630,11 +742,6 @@ object DecisionTree extends Serializable with Logging {
val totalCount = leftCount + rightCount
- val parentNodeAgg = leftImpurityCalculator.copy
- parentNodeAgg.add(rightImpurityCalculator)
-
- val impurity = parentNodeAgg.calculate()
-
val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
val rightImpurity = rightImpurityCalculator.calculate()
@@ -649,7 +756,18 @@ object DecisionTree extends Serializable with Logging {
return InformationGainStats.invalidInformationGainStats
}
- new InformationGainStats(gain, impurity, leftImpurity, rightImpurity)
+ // calculate left and right predict
+ val leftPredict = calculatePredict(leftImpurityCalculator)
+ val rightPredict = calculatePredict(rightImpurityCalculator)
+
+ new InformationGainStats(gain, impurity, leftImpurity, rightImpurity,
+ leftPredict, rightPredict)
+ }
+
+ private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = {
+ val predict = impurityCalculator.predict
+ val prob = impurityCalculator.prob(predict)
+ new Predict(predict, prob)
}
/**
@@ -657,17 +775,17 @@ object DecisionTree extends Serializable with Logging {
* Note that this function is called only once for each node.
* @param leftImpurityCalculator left node aggregates for a split
* @param rightImpurityCalculator right node aggregates for a split
- * @return predict value for current node
+ * @return predict value and impurity for current node
*/
- private def calculatePredict(
+ private def calculatePredictImpurity(
leftImpurityCalculator: ImpurityCalculator,
- rightImpurityCalculator: ImpurityCalculator): Predict = {
+ rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = {
val parentNodeAgg = leftImpurityCalculator.copy
parentNodeAgg.add(rightImpurityCalculator)
- val predict = parentNodeAgg.predict
- val prob = parentNodeAgg.prob(predict)
+ val predict = calculatePredict(parentNodeAgg)
+ val impurity = parentNodeAgg.calculate()
- new Predict(predict, prob)
+ (predict, impurity)
}
/**
@@ -678,10 +796,16 @@ object DecisionTree extends Serializable with Logging {
private def binsToBestSplit(
binAggregates: DTStatsAggregator,
splits: Array[Array[Split]],
- featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = {
+ featuresForNode: Option[Array[Int]],
+ node: Node): (Split, InformationGainStats, Predict) = {
- // calculate predict only once
- var predict: Option[Predict] = None
+ // calculate predict and impurity if current node is top node
+ val level = Node.indexToLevel(node.id)
+ var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) {
+ None
+ } else {
+ Some((node.predict, node.impurity))
+ }
// For each (feature, split), calculate the gain, and select the best (feature, split).
val (bestSplit, bestSplitStats) =
@@ -708,9 +832,10 @@ object DecisionTree extends Serializable with Logging {
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
- predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
+ predictWithImpurity = Some(predictWithImpurity.getOrElse(
+ calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
- rightChildStats, binAggregates.metadata)
+ rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIdx, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
@@ -722,9 +847,10 @@ object DecisionTree extends Serializable with Logging {
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
- predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
+ predictWithImpurity = Some(predictWithImpurity.getOrElse(
+ calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
- rightChildStats, binAggregates.metadata)
+ rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
@@ -794,9 +920,10 @@ object DecisionTree extends Serializable with Logging {
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
rightChildStats.subtract(leftChildStats)
- predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
+ predictWithImpurity = Some(predictWithImpurity.getOrElse(
+ calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
- rightChildStats, binAggregates.metadata)
+ rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
val categoriesForSplit =
@@ -807,9 +934,7 @@ object DecisionTree extends Serializable with Logging {
}
}.maxBy(_._2.gain)
- assert(predict.isDefined, "must calculate predict for each node")
-
- (bestSplit, bestSplitStats, predict.get)
+ (bestSplit, bestSplitStats, predictWithImpurity.get._1)
}
/**
@@ -874,32 +999,39 @@ object DecisionTree extends Serializable with Logging {
// Iterate over all features.
var featureIndex = 0
while (featureIndex < numFeatures) {
- val numSplits = metadata.numSplits(featureIndex)
- val numBins = metadata.numBins(featureIndex)
if (metadata.isContinuous(featureIndex)) {
- val numSamples = sampledInput.length
+ val featureSamples = sampledInput.map(lp => lp.features(featureIndex))
+ val featureSplits = findSplitsForContinuousFeature(featureSamples,
+ metadata, featureIndex)
+
+ val numSplits = featureSplits.length
+ val numBins = numSplits + 1
+ logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
splits(featureIndex) = new Array[Split](numSplits)
bins(featureIndex) = new Array[Bin](numBins)
- val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
- val stride: Double = numSamples.toDouble / metadata.numBins(featureIndex)
- logDebug("stride = " + stride)
- for (splitIndex <- 0 until numSplits) {
- val sampleIndex = splitIndex * stride.toInt
- // Set threshold halfway in between 2 samples.
- val threshold = (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0
+
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ val threshold = featureSplits(splitIndex)
splits(featureIndex)(splitIndex) =
new Split(featureIndex, threshold, Continuous, List())
+ splitIndex += 1
}
bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
splits(featureIndex)(0), Continuous, Double.MinValue)
- for (splitIndex <- 1 until numSplits) {
+
+ splitIndex = 1
+ while (splitIndex < numSplits) {
bins(featureIndex)(splitIndex) =
new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex),
Continuous, Double.MinValue)
+ splitIndex += 1
}
bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1),
new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
} else {
+ val numSplits = metadata.numSplits(featureIndex)
+ val numBins = metadata.numBins(featureIndex)
// Categorical feature
val featureArity = metadata.featureArity(featureIndex)
if (metadata.isUnordered(featureIndex)) {
@@ -976,4 +1108,77 @@ object DecisionTree extends Serializable with Logging {
categories
}
+ /**
+ * Find splits for a continuous feature
+ * NOTE: Returned number of splits is set based on `featureSamples` and
+ * could be different from the specified `numSplits`.
+ * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
+ * @param featureSamples feature values of each sample
+ * @param metadata decision tree metadata
+ * NOTE: `metadata.numbins` will be changed accordingly
+ * if there are not enough splits to be found
+ * @param featureIndex feature index to find splits
+ * @return array of splits
+ */
+ private[tree] def findSplitsForContinuousFeature(
+ featureSamples: Array[Double],
+ metadata: DecisionTreeMetadata,
+ featureIndex: Int): Array[Double] = {
+ require(metadata.isContinuous(featureIndex),
+ "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
+
+ val splits = {
+ val numSplits = metadata.numSplits(featureIndex)
+
+ // get count for each distinct value
+ val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
+ m + ((x, m.getOrElse(x, 0) + 1))
+ }
+ // sort distinct values
+ val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
+
+ // if possible splits is not enough or just enough, just return all possible splits
+ val possibleSplits = valueCounts.length
+ if (possibleSplits <= numSplits) {
+ valueCounts.map(_._1)
+ } else {
+ // stride between splits
+ val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
+ logDebug("stride = " + stride)
+
+ // iterate `valueCount` to find splits
+ val splits = new ArrayBuffer[Double]
+ var index = 1
+ // currentCount: sum of counts of values that have been visited
+ var currentCount = valueCounts(0)._2
+ // targetCount: target value for `currentCount`.
+ // If `currentCount` is closest value to `targetCount`,
+ // then current value is a split threshold.
+ // After finding a split threshold, `targetCount` is added by stride.
+ var targetCount = stride
+ while (index < valueCounts.length) {
+ val previousCount = currentCount
+ currentCount += valueCounts(index)._2
+ val previousGap = math.abs(previousCount - targetCount)
+ val currentGap = math.abs(currentCount - targetCount)
+ // If adding count of current value to currentCount
+ // makes the gap between currentCount and targetCount smaller,
+ // previous value is a split threshold.
+ if (previousGap < currentGap) {
+ splits.append(valueCounts(index - 1)._1)
+ targetCount += stride
+ }
+ index += 1
+ }
+
+ splits.toArray
+ }
+ }
+
+ assert(splits.length > 0)
+ // set number of splits accordingly
+ metadata.setNumSplits(featureIndex, splits.length)
+
+ splits
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
new file mode 100644
index 0000000000000..61f6b1313f82e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.BoostingStrategy
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.impl.TimeTracker
+import org.apache.spark.mllib.tree.impurity.Variance
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * :: Experimental ::
+ * A class that implements
+ * [[http://en.wikipedia.org/wiki/Gradient_boosting Stochastic Gradient Boosting]]
+ * for regression and binary classification.
+ *
+ * The implementation is based upon:
+ * J.H. Friedman. "Stochastic Gradient Boosting." 1999.
+ *
+ * Notes on Gradient Boosting vs. TreeBoost:
+ * - This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
+ * - Both algorithms learn tree ensembles by minimizing loss functions.
+ * - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes
+ * based on the loss function, whereas the original gradient boosting method does not.
+ * - When the loss is SquaredError, these methods give the same result, but they could differ
+ * for other loss functions.
+ *
+ * @param boostingStrategy Parameters for the gradient boosting algorithm.
+ */
+@Experimental
+class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
+ extends Serializable with Logging {
+
+ /**
+ * Method to train a gradient boosting model
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @return a gradient boosted trees model that can be used for prediction
+ */
+ def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
+ val algo = boostingStrategy.treeStrategy.algo
+ algo match {
+ case Regression => GradientBoostedTrees.boost(input, boostingStrategy)
+ case Classification =>
+ // Map labels to -1, +1 so binary classification can be treated as regression.
+ val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ GradientBoostedTrees.boost(remappedInput, boostingStrategy)
+ case _ =>
+ throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
+ }
+ }
+
+ /**
+ * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#run]].
+ */
+ def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
+ run(input.rdd)
+ }
+}
+
+
+object GradientBoostedTrees extends Logging {
+
+ /**
+ * Method to train a gradient boosting model.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * For classification, labels should take values {0, 1, ..., numClasses-1}.
+ * For regression, labels are real numbers.
+ * @param boostingStrategy Configuration options for the boosting algorithm.
+ * @return a gradient boosted trees model that can be used for prediction
+ */
+ def train(
+ input: RDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
+ new GradientBoostedTrees(boostingStrategy).run(input)
+ }
+
+ /**
+ * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees$#train]]
+ */
+ def train(
+ input: JavaRDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
+ train(input.rdd, boostingStrategy)
+ }
+
+ /**
+ * Internal method for performing regression using trees as base learners.
+ * @param input training dataset
+ * @param boostingStrategy boosting parameters
+ * @return a gradient boosted trees model that can be used for prediction
+ */
+ private def boost(
+ input: RDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
+
+ val timer = new TimeTracker()
+ timer.start("total")
+ timer.start("init")
+
+ boostingStrategy.assertValid()
+
+ // Initialize gradient boosting parameters
+ val numIterations = boostingStrategy.numIterations
+ val baseLearners = new Array[DecisionTreeModel](numIterations)
+ val baseLearnerWeights = new Array[Double](numIterations)
+ val loss = boostingStrategy.loss
+ val learningRate = boostingStrategy.learningRate
+ // Prepare strategy for individual trees, which use regression with variance impurity.
+ val treeStrategy = boostingStrategy.treeStrategy.copy
+ treeStrategy.algo = Regression
+ treeStrategy.impurity = Variance
+ treeStrategy.assertValid()
+
+ // Cache input
+ if (input.getStorageLevel == StorageLevel.NONE) {
+ input.persist(StorageLevel.MEMORY_AND_DISK)
+ }
+
+ timer.stop("init")
+
+ logDebug("##########")
+ logDebug("Building tree 0")
+ logDebug("##########")
+ var data = input
+
+ // Initialize tree
+ timer.start("building tree 0")
+ val firstTreeModel = new DecisionTree(treeStrategy).run(data)
+ baseLearners(0) = firstTreeModel
+ baseLearnerWeights(0) = 1.0
+ val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0))
+ logDebug("error of gbt = " + loss.computeError(startingModel, input))
+ // Note: A model of type regression is used since we require raw prediction
+ timer.stop("building tree 0")
+
+ // psuedo-residual for second iteration
+ data = input.map(point => LabeledPoint(loss.gradient(startingModel, point),
+ point.features))
+
+ var m = 1
+ while (m < numIterations) {
+ timer.start(s"building tree $m")
+ logDebug("###################################################")
+ logDebug("Gradient boosting tree iteration " + m)
+ logDebug("###################################################")
+ val model = new DecisionTree(treeStrategy).run(data)
+ timer.stop(s"building tree $m")
+ // Create partial model
+ baseLearners(m) = model
+ // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
+ // Technically, the weight should be optimized for the particular loss.
+ // However, the behavior should be reasonable, though not optimal.
+ baseLearnerWeights(m) = learningRate
+ // Note: A model of type regression is used since we require raw prediction
+ val partialModel = new GradientBoostedTreesModel(
+ Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1))
+ logDebug("error of gbt = " + loss.computeError(partialModel, input))
+ // Update data with pseudo-residuals
+ data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
+ point.features))
+ m += 1
+ }
+
+ timer.stop("total")
+
+ logInfo("Internal timing for DecisionTree:")
+ logInfo(s"$timer")
+
+ new GradientBoostedTreesModel(
+ boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index fa7a26f17c3ca..482d3395516e7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -17,17 +17,18 @@
package org.apache.spark.mllib.tree
-import scala.collection.JavaConverters._
import scala.collection.mutable
+import scala.collection.JavaConverters._
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker}
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, NodeIdCache,
+ TimeTracker, TreePoint}
import org.apache.spark.mllib.tree.impurity.Impurities
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
@@ -36,7 +37,8 @@ import org.apache.spark.util.Utils
/**
* :: Experimental ::
- * A class which implements a random forest learning algorithm for classification and regression.
+ * A class that implements a [[http://en.wikipedia.org/wiki/Random_forest Random Forest]]
+ * learning algorithm for classification and regression.
* It supports both continuous and categorical features.
*
* The settings for featureSubsetStrategy are based on the following references:
@@ -59,7 +61,7 @@ import org.apache.spark.util.Utils
* if numTrees == 1, set to "all";
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
- * @param seed Random seed for bootstrapping and choosing feature subsets.
+ * @param seed Random seed for bootstrapping and choosing feature subsets.
*/
@Experimental
private class RandomForest (
@@ -69,6 +71,47 @@ private class RandomForest (
private val seed: Int)
extends Serializable with Logging {
+ /*
+ ALGORITHM
+ This is a sketch of the algorithm to help new developers.
+
+ The algorithm partitions data by instances (rows).
+ On each iteration, the algorithm splits a set of nodes. In order to choose the best split
+ for a given node, sufficient statistics are collected from the distributed data.
+ For each node, the statistics are collected to some worker node, and that worker selects
+ the best split.
+
+ This setup requires discretization of continuous features. This binning is done in the
+ findSplitsBins() method during initialization, after which each continuous feature becomes
+ an ordered discretized feature with at most maxBins possible values.
+
+ The main loop in the algorithm operates on a queue of nodes (nodeQueue). These nodes
+ lie at the periphery of the tree being trained. If multiple trees are being trained at once,
+ then this queue contains nodes from all of them. Each iteration works roughly as follows:
+ On the master node:
+ - Some number of nodes are pulled off of the queue (based on the amount of memory
+ required for their sufficient statistics).
+ - For random forests, if featureSubsetStrategy is not "all," then a subset of candidate
+ features are chosen for each node. See method selectNodesToSplit().
+ On worker nodes, via method findBestSplits():
+ - The worker makes one pass over its subset of instances.
+ - For each (tree, node, feature, split) tuple, the worker collects statistics about
+ splitting. Note that the set of (tree, node) pairs is limited to the nodes selected
+ from the queue for this iteration. The set of features considered can also be limited
+ based on featureSubsetStrategy.
+ - For each node, the statistics for that node are aggregated to a particular worker
+ via reduceByKey(). The designated worker chooses the best (feature, split) pair,
+ or chooses to stop splitting if the stopping criteria are met.
+ On the master node:
+ - The master collects all decisions about splitting nodes and updates the model.
+ - The updated model is passed to the workers on the next iteration.
+ This process continues until the node queue is empty.
+
+ Most of the methods in this implementation support the statistics aggregation, which is
+ the heaviest part of the computation. In general, this implementation is bound by either
+ the cost of statistics computation on workers or by communicating the sufficient statistics.
+ */
+
strategy.assertValid()
require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.")
require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy),
@@ -78,9 +121,9 @@ private class RandomForest (
/**
* Method to train a decision tree model over an RDD
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
- * @return RandomForestModel that can be used for prediction
+ * @return a random forest model that can be used for prediction
*/
- def train(input: RDD[LabeledPoint]): RandomForestModel = {
+ def run(input: RDD[LabeledPoint]): RandomForestModel = {
val timer = new TimeTracker()
@@ -111,11 +154,20 @@ private class RandomForest (
// Bin feature values (TreePoint representation).
// Cache input RDD for speedup during multiple passes.
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
- val baggedInput = if (numTrees > 1) {
- BaggedPoint.convertToBaggedRDD(treeInput, numTrees, seed)
- } else {
- BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
- }.persist(StorageLevel.MEMORY_AND_DISK)
+
+ val (subsample, withReplacement) = {
+ // TODO: Have a stricter check for RF in the strategy
+ val isRandomForest = numTrees > 1
+ if (isRandomForest) {
+ (1.0, true)
+ } else {
+ (strategy.subsamplingRate, false)
+ }
+ }
+
+ val baggedInput
+ = BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement, seed)
+ .persist(StorageLevel.MEMORY_AND_DISK)
// depth of the decision tree
val maxDepth = strategy.maxDepth
@@ -150,6 +202,19 @@ private class RandomForest (
* in lower levels).
*/
+ // Create an RDD of node Id cache.
+ // At first, all the rows belong to the root nodes (node Id == 1).
+ val nodeIdCache = if (strategy.useNodeIdCache) {
+ Some(NodeIdCache.init(
+ data = baggedInput,
+ numTrees = numTrees,
+ checkpointDir = strategy.checkpointDir,
+ checkpointInterval = strategy.checkpointInterval,
+ initVal = 1))
+ } else {
+ None
+ }
+
// FIFO queue of nodes to train: (treeIndex, node)
val nodeQueue = new mutable.Queue[(Int, Node)]()
@@ -172,17 +237,24 @@ private class RandomForest (
// Choose node splits, and enqueue new nodes as needed.
timer.start("findBestSplits")
DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
- treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
+ treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
timer.stop("findBestSplits")
}
+ baggedInput.unpersist()
+
timer.stop("total")
logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")
+ // Delete any remaining checkpoints used for node Id cache.
+ if (nodeIdCache.nonEmpty) {
+ nodeIdCache.get.deleteAllCheckpoints()
+ }
+
val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
- RandomForestModel.build(trees)
+ new RandomForestModel(strategy.algo, trees)
}
}
@@ -200,10 +272,9 @@ object RandomForest extends Serializable with Logging {
* Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
- * if numTrees > 1 (forest) set to "sqrt" for classification and
- * to "onethird" for regression.
+ * if numTrees > 1 (forest) set to "sqrt".
* @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return RandomForestModel that can be used for prediction
+ * @return a random forest model that can be used for prediction
*/
def trainClassifier(
input: RDD[LabeledPoint],
@@ -214,7 +285,7 @@ object RandomForest extends Serializable with Logging {
require(strategy.algo == Classification,
s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}")
val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
- rf.train(input)
+ rf.run(input)
}
/**
@@ -231,8 +302,7 @@ object RandomForest extends Serializable with Logging {
* Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
- * if numTrees > 1 (forest) set to "sqrt" for classification and
- * to "onethird" for regression.
+ * if numTrees > 1 (forest) set to "sqrt".
* @param impurity Criterion used for information gain calculation.
* Supported values: "gini" (recommended) or "entropy".
* @param maxDepth Maximum depth of the tree.
@@ -241,7 +311,7 @@ object RandomForest extends Serializable with Logging {
* @param maxBins maximum number of bins used for splitting features
* (suggested value: 100)
* @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return RandomForestModel that can be used for prediction
+ * @return a random forest model that can be used for prediction
*/
def trainClassifier(
input: RDD[LabeledPoint],
@@ -288,10 +358,9 @@ object RandomForest extends Serializable with Logging {
* Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
- * if numTrees > 1 (forest) set to "sqrt" for classification and
- * to "onethird" for regression.
+ * if numTrees > 1 (forest) set to "onethird".
* @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return RandomForestModel that can be used for prediction
+ * @return a random forest model that can be used for prediction
*/
def trainRegressor(
input: RDD[LabeledPoint],
@@ -302,7 +371,7 @@ object RandomForest extends Serializable with Logging {
require(strategy.algo == Regression,
s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}")
val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
- rf.train(input)
+ rf.run(input)
}
/**
@@ -318,8 +387,7 @@ object RandomForest extends Serializable with Logging {
* Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
- * if numTrees > 1 (forest) set to "sqrt" for classification and
- * to "onethird" for regression.
+ * if numTrees > 1 (forest) set to "onethird".
* @param impurity Criterion used for information gain calculation.
* Supported values: "variance".
* @param maxDepth Maximum depth of the tree.
@@ -328,7 +396,7 @@ object RandomForest extends Serializable with Logging {
* @param maxBins maximum number of bins used for splitting features
* (suggested value: 100)
* @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return RandomForestModel that can be used for prediction
+ * @return a random forest model that can be used for prediction
*/
def trainRegressor(
input: RDD[LabeledPoint],
@@ -448,5 +516,4 @@ object RandomForest extends Serializable with Logging {
3 * totalBins
}
}
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
new file mode 100644
index 0000000000000..e703adbdbfbb3
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+import scala.beans.BeanProperty
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
+
+/**
+ * :: Experimental ::
+ * Configuration options for [[org.apache.spark.mllib.tree.GradientBoostedTrees]].
+ *
+ * @param treeStrategy Parameters for the tree algorithm. We support regression and binary
+ * classification for boosting. Impurity setting will be ignored.
+ * @param loss Loss function used for minimization during gradient boosting.
+ * @param numIterations Number of iterations of boosting. In other words, the number of
+ * weak hypotheses used in the final model.
+ * @param learningRate Learning rate for shrinking the contribution of each estimator. The
+ * learning rate should be between in the interval (0, 1]
+ */
+@Experimental
+case class BoostingStrategy(
+ // Required boosting parameters
+ @BeanProperty var treeStrategy: Strategy,
+ @BeanProperty var loss: Loss,
+ // Optional boosting parameters
+ @BeanProperty var numIterations: Int = 100,
+ @BeanProperty var learningRate: Double = 0.1) extends Serializable {
+
+ /**
+ * Check validity of parameters.
+ * Throws exception if invalid.
+ */
+ private[tree] def assertValid(): Unit = {
+ treeStrategy.algo match {
+ case Classification =>
+ require(treeStrategy.numClassesForClassification == 2,
+ "Only binary classification is supported for boosting.")
+ case Regression =>
+ // nothing
+ case _ =>
+ throw new IllegalArgumentException(
+ s"BoostingStrategy given invalid algo parameter: ${treeStrategy.algo}." +
+ s" Valid settings are: Classification, Regression.")
+ }
+ require(learningRate > 0 && learningRate <= 1,
+ "Learning rate should be in range (0, 1]. Provided learning rate is " + s"$learningRate.")
+ }
+}
+
+@Experimental
+object BoostingStrategy {
+
+ /**
+ * Returns default configuration for the boosting algorithm
+ * @param algo Learning goal. Supported:
+ * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
+ * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
+ * @return Configuration for boosting algorithm
+ */
+ def defaultParams(algo: String): BoostingStrategy = {
+ val treeStrategy = Strategy.defaultStrategy(algo)
+ treeStrategy.maxDepth = 3
+ algo match {
+ case "Classification" =>
+ treeStrategy.numClassesForClassification = 2
+ new BoostingStrategy(treeStrategy, LogLoss)
+ case "Regression" =>
+ new BoostingStrategy(treeStrategy, SquaredError)
+ case _ =>
+ throw new IllegalArgumentException(s"$algo is not supported by the boosting.")
+ }
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
new file mode 100644
index 0000000000000..b5bf732d1b33a
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+/**
+ * Enum to select ensemble combining strategy for base learners
+ */
+private[tree] object EnsembleCombiningStrategy extends Enumeration {
+ type EnsembleCombiningStrategy = Value
+ val Average, Sum, Vote = Value
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index caaccbfb8ad16..d75f38433c081 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -17,6 +17,7 @@
package org.apache.spark.mllib.tree.configuration
+import scala.beans.BeanProperty
import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
@@ -43,7 +44,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* for choosing how to split on features at each node.
* More bins give higher granularity.
* @param quantileCalculationStrategy Algorithm for calculating quantiles. Supported:
- * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]]
+ * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]]
* @param categoricalFeaturesInfo A map storing information about the categorical variables and the
* number of discrete values they take. For example, an entry (n ->
* k) implies the feature n is categorical with k categories 0,
@@ -58,31 +59,35 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* this split will not be considered as a valid split.
* @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
* 256 MB.
+ * @param subsamplingRate Fraction of the training data used for learning decision tree.
+ * @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will
+ * maintain a separate RDD of node Id cache for each row.
+ * @param checkpointDir If the node Id cache is used, it will help to checkpoint
+ * the node Id cache periodically. This is the checkpoint directory
+ * to be used for the node Id cache.
+ * @param checkpointInterval How often to checkpoint when the node Id cache gets updated.
+ * E.g. 10 means that the cache will get checkpointed every 10 updates.
*/
@Experimental
class Strategy (
- val algo: Algo,
- val impurity: Impurity,
- val maxDepth: Int,
- val numClassesForClassification: Int = 2,
- val maxBins: Int = 32,
- val quantileCalculationStrategy: QuantileStrategy = Sort,
- val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
- val minInstancesPerNode: Int = 1,
- val minInfoGain: Double = 0.0,
- val maxMemoryInMB: Int = 256) extends Serializable {
+ @BeanProperty var algo: Algo,
+ @BeanProperty var impurity: Impurity,
+ @BeanProperty var maxDepth: Int,
+ @BeanProperty var numClassesForClassification: Int = 2,
+ @BeanProperty var maxBins: Int = 32,
+ @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort,
+ @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
+ @BeanProperty var minInstancesPerNode: Int = 1,
+ @BeanProperty var minInfoGain: Double = 0.0,
+ @BeanProperty var maxMemoryInMB: Int = 256,
+ @BeanProperty var subsamplingRate: Double = 1,
+ @BeanProperty var useNodeIdCache: Boolean = false,
+ @BeanProperty var checkpointDir: Option[String] = None,
+ @BeanProperty var checkpointInterval: Int = 10) extends Serializable {
- if (algo == Classification) {
- require(numClassesForClassification >= 2)
- }
- require(minInstancesPerNode >= 1,
- s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
- require(maxMemoryInMB <= 10240,
- s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
-
- val isMulticlassClassification =
+ def isMulticlassClassification =
algo == Classification && numClassesForClassification > 2
- val isMulticlassWithCategoricalFeatures
+ def isMulticlassWithCategoricalFeatures
= isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
/**
@@ -99,6 +104,23 @@ class Strategy (
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
}
+ /**
+ * Sets Algorithm using a String.
+ */
+ def setAlgo(algo: String): Unit = algo match {
+ case "Classification" => setAlgo(Classification)
+ case "Regression" => setAlgo(Regression)
+ }
+
+ /**
+ * Sets categoricalFeaturesInfo using a Java Map.
+ */
+ def setCategoricalFeaturesInfo(
+ categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]): Unit = {
+ setCategoricalFeaturesInfo(
+ categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
+ }
+
/**
* Check validity of parameters.
* Throws exception if invalid.
@@ -130,6 +152,33 @@ class Strategy (
s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" +
s" feature $feature has $arity categories. The number of categories should be >= 2.")
}
+ require(minInstancesPerNode >= 1,
+ s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
+ require(maxMemoryInMB <= 10240,
+ s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
}
+ /** Returns a shallow copy of this instance. */
+ def copy: Strategy = {
+ new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
+ quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain,
+ maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointDir, checkpointInterval)
+ }
+}
+
+@Experimental
+object Strategy {
+
+ /**
+ * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
+ * @param algo "Classification" or "Regression"
+ */
+ def defaultStrategy(algo: String): Strategy = algo match {
+ case "Classification" =>
+ new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
+ numClassesForClassification = 2)
+ case "Regression" =>
+ new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
+ numClassesForClassification = 0)
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
index 937c8a2ac5836..089010c81ffb6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
@@ -17,18 +17,18 @@
package org.apache.spark.mllib.tree.impl
-import cern.jet.random.Poisson
-import cern.jet.random.engine.DRand
+import org.apache.commons.math3.distribution.PoissonDistribution
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
/**
* Internal representation of a datapoint which belongs to several subsamples of the same dataset,
* particularly for bagging (e.g., for random forests).
*
* This holds one instance, as well as an array of weights which represent the (weighted)
- * number of times which this instance appears in each subsample.
+ * number of times which this instance appears in each subsamplingRate.
* E.g., (datum, [1, 0, 4]) indicates that there are 3 subsamples of the dataset and that
* this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively.
*
@@ -45,27 +45,71 @@ private[tree] object BaggedPoint {
/**
* Convert an input dataset into its BaggedPoint representation,
- * choosing subsample counts for each instance.
- * Each subsample has the same number of instances as the original dataset,
- * and is created by subsampling with replacement.
- * @param input Input dataset.
- * @param numSubsamples Number of subsamples of this RDD to take.
- * @param seed Random seed.
- * @return BaggedPoint dataset representation
+ * choosing subsamplingRate counts for each instance.
+ * Each subsamplingRate has the same number of instances as the original dataset,
+ * and is created by subsampling without replacement.
+ * @param input Input dataset.
+ * @param subsamplingRate Fraction of the training data used for learning decision tree.
+ * @param numSubsamples Number of subsamples of this RDD to take.
+ * @param withReplacement Sampling with/without replacement.
+ * @param seed Random seed.
+ * @return BaggedPoint dataset representation.
*/
- def convertToBaggedRDD[Datum](
+ def convertToBaggedRDD[Datum] (
input: RDD[Datum],
+ subsamplingRate: Double,
numSubsamples: Int,
+ withReplacement: Boolean,
seed: Int = Utils.random.nextInt()): RDD[BaggedPoint[Datum]] = {
+ if (withReplacement) {
+ convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)
+ } else {
+ if (numSubsamples == 1 && subsamplingRate == 1.0) {
+ convertToBaggedRDDWithoutSampling(input)
+ } else {
+ convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)
+ }
+ }
+ }
+
+ private def convertToBaggedRDDSamplingWithoutReplacement[Datum] (
+ input: RDD[Datum],
+ subsamplingRate: Double,
+ numSubsamples: Int,
+ seed: Int): RDD[BaggedPoint[Datum]] = {
+ input.mapPartitionsWithIndex { (partitionIndex, instances) =>
+ // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
+ val rng = new XORShiftRandom
+ rng.setSeed(seed + partitionIndex + 1)
+ instances.map { instance =>
+ val subsampleWeights = new Array[Double](numSubsamples)
+ var subsampleIndex = 0
+ while (subsampleIndex < numSubsamples) {
+ val x = rng.nextDouble()
+ subsampleWeights(subsampleIndex) = {
+ if (x < subsamplingRate) 1.0 else 0.0
+ }
+ subsampleIndex += 1
+ }
+ new BaggedPoint(instance, subsampleWeights)
+ }
+ }
+ }
+
+ private def convertToBaggedRDDSamplingWithReplacement[Datum] (
+ input: RDD[Datum],
+ subsample: Double,
+ numSubsamples: Int,
+ seed: Int): RDD[BaggedPoint[Datum]] = {
input.mapPartitionsWithIndex { (partitionIndex, instances) =>
- // TODO: Support different sampling rates, and sampling without replacement.
// Use random seed = seed + partitionIndex + 1 to make generation reproducible.
- val poisson = new Poisson(1.0, new DRand(seed + partitionIndex + 1))
+ val poisson = new PoissonDistribution(subsample)
+ poisson.reseedRandomGenerator(seed + partitionIndex + 1)
instances.map { instance =>
val subsampleWeights = new Array[Double](numSubsamples)
var subsampleIndex = 0
while (subsampleIndex < numSubsamples) {
- subsampleWeights(subsampleIndex) = poisson.nextInt()
+ subsampleWeights(subsampleIndex) = poisson.sample()
subsampleIndex += 1
}
new BaggedPoint(instance, subsampleWeights)
@@ -73,7 +117,8 @@ private[tree] object BaggedPoint {
}
}
- def convertToBaggedRDDWithoutSampling[Datum](input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
+ private def convertToBaggedRDDWithoutSampling[Datum] (
+ input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
input.map(datum => new BaggedPoint(datum, Array(1.0)))
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
index 55f422dff0d71..ce8825cc03229 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -64,12 +64,6 @@ private[tree] class DTStatsAggregator(
numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
}
- /**
- * Indicator for each feature of whether that feature is an unordered feature.
- * TODO: Is Array[Boolean] any faster?
- */
- def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex)
-
/**
* Total number of elements stored in this aggregator
*/
@@ -128,21 +122,13 @@ private[tree] class DTStatsAggregator(
* Pre-compute feature offset for use with [[featureUpdate]].
* For ordered features only.
*/
- def getFeatureOffset(featureIndex: Int): Int = {
- require(!isUnordered(featureIndex),
- s"DTStatsAggregator.getFeatureOffset is for ordered features only, but was called" +
- s" for unordered feature $featureIndex.")
- featureOffsets(featureIndex)
- }
+ def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex)
/**
* Pre-compute feature offset for use with [[featureUpdate]].
* For unordered features only.
*/
def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = {
- require(isUnordered(featureIndex),
- s"DTStatsAggregator.getLeftRightFeatureOffsets is for unordered features only," +
- s" but was called for ordered feature $featureIndex.")
val baseOffset = featureOffsets(featureIndex)
(baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index 212dce25236e0..5bc0f2635c6b1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.tree.impl
import scala.collection.mutable
+import org.apache.spark.Logging
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
@@ -75,6 +76,17 @@ private[tree] class DecisionTreeMetadata(
numBins(featureIndex) - 1
}
+
+ /**
+ * Set number of splits for a continuous feature.
+ * For a continuous feature, number of bins is number of splits plus 1.
+ */
+ def setNumSplits(featureIndex: Int, numSplits: Int) {
+ require(isContinuous(featureIndex),
+ s"Only number of bin for a continuous feature can be set.")
+ numBins(featureIndex) = numSplits + 1
+ }
+
/**
* Indicates if feature subsampling is being used.
*/
@@ -82,7 +94,7 @@ private[tree] class DecisionTreeMetadata(
}
-private[tree] object DecisionTreeMetadata {
+private[tree] object DecisionTreeMetadata extends Logging {
/**
* Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters.
@@ -103,6 +115,10 @@ private[tree] object DecisionTreeMetadata {
}
val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt
+ if (maxPossibleBins < strategy.maxBins) {
+ logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" +
+ s" (= number of training instances)")
+ }
// We check the number of bins here against maxPossibleBins.
// This needs to be checked here instead of in Strategy since maxPossibleBins can be modified
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
new file mode 100644
index 0000000000000..83011b48b7d9b
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impl
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.{Path, FileSystem}
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.tree.configuration.FeatureType._
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.mllib.tree.model.{Bin, Node, Split}
+
+/**
+ * :: DeveloperApi ::
+ * This is used by the node id cache to find the child id that a data point would belong to.
+ * @param split Split information.
+ * @param nodeIndex The current node index of a data point that this will update.
+ */
+@DeveloperApi
+private[tree] case class NodeIndexUpdater(
+ split: Split,
+ nodeIndex: Int) {
+ /**
+ * Determine a child node index based on the feature value and the split.
+ * @param binnedFeatures Binned feature values.
+ * @param bins Bin information to convert the bin indices to approximate feature values.
+ * @return Child node index to update to.
+ */
+ def updateNodeIndex(binnedFeatures: Array[Int], bins: Array[Array[Bin]]): Int = {
+ if (split.featureType == Continuous) {
+ val featureIndex = split.feature
+ val binIndex = binnedFeatures(featureIndex)
+ val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
+ if (featureValueUpperBound <= split.threshold) {
+ Node.leftChildIndex(nodeIndex)
+ } else {
+ Node.rightChildIndex(nodeIndex)
+ }
+ } else {
+ if (split.categories.contains(binnedFeatures(split.feature).toDouble)) {
+ Node.leftChildIndex(nodeIndex)
+ } else {
+ Node.rightChildIndex(nodeIndex)
+ }
+ }
+ }
+}
+
+/**
+ * :: DeveloperApi ::
+ * A given TreePoint would belong to a particular node per tree.
+ * Each row in the nodeIdsForInstances RDD is an array over trees of the node index
+ * in each tree. Initially, values should all be 1 for root node.
+ * The nodeIdsForInstances RDD needs to be updated at each iteration.
+ * @param nodeIdsForInstances The initial values in the cache
+ * (should be an Array of all 1's (meaning the root nodes)).
+ * @param checkpointDir The checkpoint directory where
+ * the checkpointed files will be stored.
+ * @param checkpointInterval The checkpointing interval
+ * (how often should the cache be checkpointed.).
+ */
+@DeveloperApi
+private[tree] class NodeIdCache(
+ var nodeIdsForInstances: RDD[Array[Int]],
+ val checkpointDir: Option[String],
+ val checkpointInterval: Int) {
+
+ // Keep a reference to a previous node Ids for instances.
+ // Because we will keep on re-persisting updated node Ids,
+ // we want to unpersist the previous RDD.
+ private var prevNodeIdsForInstances: RDD[Array[Int]] = null
+
+ // To keep track of the past checkpointed RDDs.
+ private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
+ private var rddUpdateCount = 0
+
+ // If a checkpoint directory is given, and there's no prior checkpoint directory,
+ // then set the checkpoint directory with the given one.
+ if (checkpointDir.nonEmpty && nodeIdsForInstances.sparkContext.getCheckpointDir.isEmpty) {
+ nodeIdsForInstances.sparkContext.setCheckpointDir(checkpointDir.get)
+ }
+
+ /**
+ * Update the node index values in the cache.
+ * This updates the RDD and its lineage.
+ * TODO: Passing bin information to executors seems unnecessary and costly.
+ * @param data The RDD of training rows.
+ * @param nodeIdUpdaters A map of node index updaters.
+ * The key is the indices of nodes that we want to update.
+ * @param bins Bin information needed to find child node indices.
+ */
+ def updateNodeIndices(
+ data: RDD[BaggedPoint[TreePoint]],
+ nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]],
+ bins: Array[Array[Bin]]): Unit = {
+ if (prevNodeIdsForInstances != null) {
+ // Unpersist the previous one if one exists.
+ prevNodeIdsForInstances.unpersist()
+ }
+
+ prevNodeIdsForInstances = nodeIdsForInstances
+ nodeIdsForInstances = data.zip(nodeIdsForInstances).map {
+ dataPoint => {
+ var treeId = 0
+ while (treeId < nodeIdUpdaters.length) {
+ val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(dataPoint._2(treeId), null)
+ if (nodeIdUpdater != null) {
+ val newNodeIndex = nodeIdUpdater.updateNodeIndex(
+ binnedFeatures = dataPoint._1.datum.binnedFeatures,
+ bins = bins)
+ dataPoint._2(treeId) = newNodeIndex
+ }
+
+ treeId += 1
+ }
+
+ dataPoint._2
+ }
+ }
+
+ // Keep on persisting new ones.
+ nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK)
+ rddUpdateCount += 1
+
+ // Handle checkpointing if the directory is not None.
+ if (nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty &&
+ (rddUpdateCount % checkpointInterval) == 0) {
+ // Let's see if we can delete previous checkpoints.
+ var canDelete = true
+ while (checkpointQueue.size > 1 && canDelete) {
+ // We can delete the oldest checkpoint iff
+ // the next checkpoint actually exists in the file system.
+ if (checkpointQueue.get(1).get.getCheckpointFile != None) {
+ val old = checkpointQueue.dequeue()
+
+ // Since the old checkpoint is not deleted by Spark,
+ // we'll manually delete it here.
+ val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
+ fs.delete(new Path(old.getCheckpointFile.get), true)
+ } else {
+ canDelete = false
+ }
+ }
+
+ nodeIdsForInstances.checkpoint()
+ checkpointQueue.enqueue(nodeIdsForInstances)
+ }
+ }
+
+ /**
+ * Call this after training is finished to delete any remaining checkpoints.
+ */
+ def deleteAllCheckpoints(): Unit = {
+ while (checkpointQueue.size > 0) {
+ val old = checkpointQueue.dequeue()
+ if (old.getCheckpointFile != None) {
+ val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
+ fs.delete(new Path(old.getCheckpointFile.get), true)
+ }
+ }
+ }
+}
+
+@DeveloperApi
+private[tree] object NodeIdCache {
+ /**
+ * Initialize the node Id cache with initial node Id values.
+ * @param data The RDD of training rows.
+ * @param numTrees The number of trees that we want to create cache for.
+ * @param checkpointDir The checkpoint directory where the checkpointed files will be stored.
+ * @param checkpointInterval The checkpointing interval
+ * (how often should the cache be checkpointed.).
+ * @param initVal The initial values in the cache.
+ * @return A node Id cache containing an RDD of initial root node Indices.
+ */
+ def init(
+ data: RDD[BaggedPoint[TreePoint]],
+ numTrees: Int,
+ checkpointDir: Option[String],
+ checkpointInterval: Int,
+ initVal: Int = 1): NodeIdCache = {
+ new NodeIdCache(
+ data.map(_ => Array.fill[Int](numTrees)(initVal)),
+ checkpointDir,
+ checkpointInterval)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
new file mode 100644
index 0000000000000..d1bde15e6b150
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.TreeEnsembleModel
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: DeveloperApi ::
+ * Class for absolute error loss calculation (for regression).
+ *
+ * The absolute (L1) error is defined as:
+ * |y - F(x)|
+ * where y is the label and F(x) is the model prediction for features x.
+ */
+@DeveloperApi
+object AbsoluteError extends Loss {
+
+ /**
+ * Method to calculate the gradients for the gradient boosting calculation for least
+ * absolute error calculation.
+ * The gradient with respect to F(x) is: sign(F(x) - y)
+ * @param model Ensemble model
+ * @param point Instance of the training dataset
+ * @return Loss gradient
+ */
+ override def gradient(
+ model: TreeEnsembleModel,
+ point: LabeledPoint): Double = {
+ if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0
+ }
+
+ /**
+ * Method to calculate loss of the base learner for the gradient boosting calculation.
+ * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
+ * purposes.
+ * @param model Ensemble model
+ * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @return Mean absolute error of model on data
+ */
+ override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
+ data.map { y =>
+ val err = model.predict(y.features) - y.label
+ math.abs(err)
+ }.mean()
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
new file mode 100644
index 0000000000000..7ce9fa6f86c42
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.TreeEnsembleModel
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: DeveloperApi ::
+ * Class for log loss calculation (for classification).
+ * This uses twice the binomial negative log likelihood, called "deviance" in Friedman (1999).
+ *
+ * The log loss is defined as:
+ * 2 log(1 + exp(-2 y F(x)))
+ * where y is a label in {-1, 1} and F(x) is the model prediction for features x.
+ */
+@DeveloperApi
+object LogLoss extends Loss {
+
+ /**
+ * Method to calculate the loss gradients for the gradient boosting calculation for binary
+ * classification
+ * The gradient with respect to F(x) is: - 4 y / (1 + exp(2 y F(x)))
+ * @param model Ensemble model
+ * @param point Instance of the training dataset
+ * @return Loss gradient
+ */
+ override def gradient(
+ model: TreeEnsembleModel,
+ point: LabeledPoint): Double = {
+ val prediction = model.predict(point.features)
+ - 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction))
+ }
+
+ /**
+ * Method to calculate loss of the base learner for the gradient boosting calculation.
+ * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
+ * purposes.
+ * @param model Ensemble model
+ * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @return Mean log loss of model on data
+ */
+ override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
+ data.map { case point =>
+ val prediction = model.predict(point.features)
+ val margin = 2.0 * point.label * prediction
+ // The following are equivalent to 2.0 * log(1 + exp(-margin)) but are more numerically
+ // stable.
+ if (margin >= 0) {
+ 2.0 * math.log1p(math.exp(-margin))
+ } else {
+ 2.0 * (-margin + math.log1p(math.exp(margin)))
+ }
+ }.mean()
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
new file mode 100644
index 0000000000000..4bca9039ebe1d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.TreeEnsembleModel
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: DeveloperApi ::
+ * Trait for adding "pluggable" loss functions for the gradient boosting algorithm.
+ */
+@DeveloperApi
+trait Loss extends Serializable {
+
+ /**
+ * Method to calculate the gradients for the gradient boosting calculation.
+ * @param model Model of the weak learner.
+ * @param point Instance of the training dataset.
+ * @return Loss gradient.
+ */
+ def gradient(
+ model: TreeEnsembleModel,
+ point: LabeledPoint): Double
+
+ /**
+ * Method to calculate error of the base learner for the gradient boosting calculation.
+ * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
+ * purposes.
+ * @param model Model of the weak learner.
+ * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @return
+ */
+ def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala
new file mode 100644
index 0000000000000..42c9ead9884b4
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+object Losses {
+
+ def fromString(name: String): Loss = name match {
+ case "leastSquaresError" => SquaredError
+ case "leastAbsoluteError" => AbsoluteError
+ case "logLoss" => LogLoss
+ case _ => throw new IllegalArgumentException(s"Did not recognize Loss name: $name")
+ }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
new file mode 100644
index 0000000000000..50ecaa2f86f35
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.TreeEnsembleModel
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: DeveloperApi ::
+ * Class for squared error loss calculation.
+ *
+ * The squared (L2) error is defined as:
+ * (y - F(x))**2
+ * where y is the label and F(x) is the model prediction for features x.
+ */
+@DeveloperApi
+object SquaredError extends Loss {
+
+ /**
+ * Method to calculate the gradients for the gradient boosting calculation for least
+ * squares error calculation.
+ * The gradient with respect to F(x) is: - 2 (y - F(x))
+ * @param model Ensemble model
+ * @param point Instance of the training dataset
+ * @return Loss gradient
+ */
+ override def gradient(
+ model: TreeEnsembleModel,
+ point: LabeledPoint): Double = {
+ 2.0 * (model.predict(point.features) - point.label)
+ }
+
+ /**
+ * Method to calculate loss of the base learner for the gradient boosting calculation.
+ * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
+ * purposes.
+ * @param model Ensemble model
+ * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @return Mean squared error of model on data
+ */
+ override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
+ data.map { y =>
+ val err = model.predict(y.features) - y.label
+ err * err
+ }.mean()
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index ec1d99ab26f9c..a5760963068c3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -18,9 +18,10 @@
package org.apache.spark.mllib.tree.model
import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.rdd.RDD
-import org.apache.spark.mllib.linalg.Vector
/**
* :: Experimental ::
@@ -52,6 +53,17 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
features.map(x => predict(x))
}
+
+ /**
+ * Predict values for the given data set using the model trained.
+ *
+ * @param features JavaRDD representing data points to be predicted
+ * @return JavaRDD of predictions for each of the given data points
+ */
+ def predict(features: JavaRDD[Vector]): JavaRDD[Double] = {
+ predict(features.rdd)
+ }
+
/**
* Get number of nodes in tree, including leaf nodes.
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index a89e71e115806..9a50ecb550c38 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -26,13 +26,17 @@ import org.apache.spark.annotation.DeveloperApi
* @param impurity current node impurity
* @param leftImpurity left node impurity
* @param rightImpurity right node impurity
+ * @param leftPredict left node predict
+ * @param rightPredict right node predict
*/
@DeveloperApi
class InformationGainStats(
val gain: Double,
val impurity: Double,
val leftImpurity: Double,
- val rightImpurity: Double) extends Serializable {
+ val rightImpurity: Double,
+ val leftPredict: Predict,
+ val rightPredict: Predict) extends Serializable {
override def toString = {
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
@@ -58,5 +62,6 @@ private[tree] object InformationGainStats {
* denote that current split doesn't satisfies minimum info gain or
* minimum number of instances per node.
*/
- val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0)
+ val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0,
+ new Predict(0.0, 0.0), new Predict(0.0, 0.0))
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 56c3e25d9285f..2179da8dbe03e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -32,7 +32,8 @@ import org.apache.spark.mllib.linalg.Vector
*
* @param id integer node id, from 1
* @param predict predicted value at the node
- * @param isLeaf whether the leaf is a node
+ * @param impurity current node impurity
+ * @param isLeaf whether the node is a leaf
* @param split split to calculate left and right nodes
* @param leftNode left child
* @param rightNode right child
@@ -41,7 +42,8 @@ import org.apache.spark.mllib.linalg.Vector
@DeveloperApi
class Node (
val id: Int,
- var predict: Double,
+ var predict: Predict,
+ var impurity: Double,
var isLeaf: Boolean,
var split: Option[Split],
var leftNode: Option[Node],
@@ -49,7 +51,7 @@ class Node (
var stats: Option[InformationGainStats]) extends Serializable with Logging {
override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " +
- "split = " + split + ", stats = " + stats
+ "impurity = " + impurity + "split = " + split + ", stats = " + stats
/**
* build the left node and right nodes if not leaf
@@ -62,6 +64,7 @@ class Node (
logDebug("id = " + id + ", split = " + split)
logDebug("stats = " + stats)
logDebug("predict = " + predict)
+ logDebug("impurity = " + impurity)
if (!isLeaf) {
leftNode = Some(nodes(Node.leftChildIndex(id)))
rightNode = Some(nodes(Node.rightChildIndex(id)))
@@ -77,7 +80,7 @@ class Node (
*/
def predict(features: Vector) : Double = {
if (isLeaf) {
- predict
+ predict.predict
} else{
if (split.get.featureType == Continuous) {
if (features(split.get.feature) <= split.get.threshold) {
@@ -109,7 +112,7 @@ class Node (
} else {
Some(rightNode.get.deepCopy())
}
- new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, stats)
+ new Node(id, predict, impurity, isLeaf, split, leftNodeCopy, rightNodeCopy, stats)
}
/**
@@ -154,7 +157,7 @@ class Node (
}
val prefix: String = " " * indentFactor
if (isLeaf) {
- prefix + s"Predict: $predict\n"
+ prefix + s"Predict: ${predict.predict}\n"
} else {
prefix + s"If ${splitToString(split.get, left=true)}\n" +
leftNode.get.subtreeToString(indentFactor + 1) +
@@ -170,7 +173,27 @@ private[tree] object Node {
/**
* Return a node with the given node id (but nothing else set).
*/
- def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, 0, false, None, None, None, None)
+ def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, new Predict(Double.MinValue), -1.0,
+ false, None, None, None, None)
+
+ /**
+ * Construct a node with nodeIndex, predict, impurity and isLeaf parameters.
+ * This is used in `DecisionTree.findBestSplits` to construct child nodes
+ * after finding the best splits for parent nodes.
+ * Other fields are set at next level.
+ * @param nodeIndex integer node id, from 1
+ * @param predict predicted value at the node
+ * @param impurity current node impurity
+ * @param isLeaf whether the node is a leaf
+ * @return new node instance
+ */
+ def apply(
+ nodeIndex: Int,
+ predict: Predict,
+ impurity: Double,
+ isLeaf: Boolean): Node = {
+ new Node(nodeIndex, predict, impurity, isLeaf, None, None, None, None)
+ }
/**
* Return the index of the left child of this node.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
index d8476b5cd7bc7..004838ee5ba0e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
@@ -17,12 +17,15 @@
package org.apache.spark.mllib.tree.model
+import org.apache.spark.annotation.DeveloperApi
+
/**
* Predicted value for a node
* @param predict predicted value
* @param prob probability of the label (classification only)
*/
-private[tree] class Predict(
+@DeveloperApi
+class Predict(
val predict: Double,
val prob: Double = 0.0) extends Serializable {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
deleted file mode 100644
index 4d66d6d81caa5..0000000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.model
-
-import scala.collection.mutable
-
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.rdd.RDD
-
-/**
- * :: Experimental ::
- * Random forest model for classification or regression.
- * This model stores a collection of [[DecisionTreeModel]] instances and uses them to make
- * aggregate predictions.
- * @param trees Trees which make up this forest. This cannot be empty.
- * @param algo algorithm type -- classification or regression
- */
-@Experimental
-class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) extends Serializable {
-
- require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.")
-
- /**
- * Predict values for a single data point.
- *
- * @param features array representing a single data point
- * @return Double prediction from the trained model
- */
- def predict(features: Vector): Double = {
- algo match {
- case Classification =>
- val predictionToCount = new mutable.HashMap[Int, Int]()
- trees.foreach { tree =>
- val prediction = tree.predict(features).toInt
- predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1
- }
- predictionToCount.maxBy(_._2)._1
- case Regression =>
- trees.map(_.predict(features)).sum / trees.size
- }
- }
-
- /**
- * Predict values for the given data set.
- *
- * @param features RDD representing data points to be predicted
- * @return RDD[Double] where each entry contains the corresponding prediction
- */
- def predict(features: RDD[Vector]): RDD[Double] = {
- features.map(x => predict(x))
- }
-
- /**
- * Get number of trees in forest.
- */
- def numTrees: Int = trees.size
-
- /**
- * Get total number of nodes, summed over all trees in the forest.
- */
- def totalNumNodes: Int = trees.map(tree => tree.numNodes).sum
-
- /**
- * Print a summary of the model.
- */
- override def toString: String = algo match {
- case Classification =>
- s"RandomForestModel classifier with $numTrees trees"
- case Regression =>
- s"RandomForestModel regressor with $numTrees trees"
- case _ => throw new IllegalArgumentException(
- s"RandomForestModel given unknown algo parameter: $algo.")
- }
-
- /**
- * Print the full model to a string.
- */
- def toDebugString: String = {
- val header = toString + "\n"
- header + trees.zipWithIndex.map { case (tree, treeIndex) =>
- s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4)
- }.fold("")(_ + _)
- }
-
-}
-
-private[tree] object RandomForestModel {
-
- def build(trees: Array[DecisionTreeModel]): RandomForestModel = {
- require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.")
- val algo: Algo = trees(0).algo
- require(trees.forall(_.algo == algo),
- "RandomForestModel cannot combine trees which have different output types" +
- " (classification/regression).")
- new RandomForestModel(trees, algo)
- }
-
-}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
new file mode 100644
index 0000000000000..22997110de8dd
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import scala.collection.mutable
+
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: Experimental ::
+ * Represents a random forest model.
+ *
+ * @param algo algorithm for the ensemble model, either Classification or Regression
+ * @param trees tree ensembles
+ */
+@Experimental
+class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel])
+ extends TreeEnsembleModel(algo, trees, Array.fill(trees.size)(1.0),
+ combiningStrategy = if (algo == Classification) Vote else Average) {
+
+ require(trees.forall(_.algo == algo))
+}
+
+/**
+ * :: Experimental ::
+ * Represents a gradient boosted trees model.
+ *
+ * @param algo algorithm for the ensemble model, either Classification or Regression
+ * @param trees tree ensembles
+ * @param treeWeights tree ensemble weights
+ */
+@Experimental
+class GradientBoostedTreesModel(
+ override val algo: Algo,
+ override val trees: Array[DecisionTreeModel],
+ override val treeWeights: Array[Double])
+ extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) {
+
+ require(trees.size == treeWeights.size)
+}
+
+/**
+ * Represents a tree ensemble model.
+ *
+ * @param algo algorithm for the ensemble model, either Classification or Regression
+ * @param trees tree ensembles
+ * @param treeWeights tree ensemble weights
+ * @param combiningStrategy strategy for combining the predictions, not used for regression.
+ */
+private[tree] sealed class TreeEnsembleModel(
+ protected val algo: Algo,
+ protected val trees: Array[DecisionTreeModel],
+ protected val treeWeights: Array[Double],
+ protected val combiningStrategy: EnsembleCombiningStrategy) extends Serializable {
+
+ require(numTrees > 0, "TreeEnsembleModel cannot be created without trees.")
+
+ private val sumWeights = math.max(treeWeights.sum, 1e-15)
+
+ /**
+ * Predicts for a single data point using the weighted sum of ensemble predictions.
+ *
+ * @param features array representing a single data point
+ * @return predicted category from the trained model
+ */
+ private def predictBySumming(features: Vector): Double = {
+ val treePredictions = trees.map(_.predict(features))
+ blas.ddot(numTrees, treePredictions, 1, treeWeights, 1)
+ }
+
+ /**
+ * Classifies a single data point based on (weighted) majority votes.
+ */
+ private def predictByVoting(features: Vector): Double = {
+ val votes = mutable.Map.empty[Int, Double]
+ trees.view.zip(treeWeights).foreach { case (tree, weight) =>
+ val prediction = tree.predict(features).toInt
+ votes(prediction) = votes.getOrElse(prediction, 0.0) + weight
+ }
+ votes.maxBy(_._2)._1
+ }
+
+ /**
+ * Predict values for a single data point using the model trained.
+ *
+ * @param features array representing a single data point
+ * @return predicted category from the trained model
+ */
+ def predict(features: Vector): Double = {
+ (algo, combiningStrategy) match {
+ case (Regression, Sum) =>
+ predictBySumming(features)
+ case (Regression, Average) =>
+ predictBySumming(features) / sumWeights
+ case (Classification, Sum) => // binary classification
+ val prediction = predictBySumming(features)
+ // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info.
+ if (prediction > 0.0) 1.0 else 0.0
+ case (Classification, Vote) =>
+ predictByVoting(features)
+ case _ =>
+ throw new IllegalArgumentException(
+ "TreeEnsembleModel given unsupported (algo, combiningStrategy) combination: " +
+ s"($algo, $combiningStrategy).")
+ }
+ }
+
+ /**
+ * Predict values for the given data set.
+ *
+ * @param features RDD representing data points to be predicted
+ * @return RDD[Double] where each entry contains the corresponding prediction
+ */
+ def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x))
+
+ /**
+ * Java-friendly version of [[org.apache.spark.mllib.tree.model.TreeEnsembleModel#predict]].
+ */
+ def predict(features: JavaRDD[Vector]): JavaRDD[java.lang.Double] = {
+ predict(features.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
+ }
+
+ /**
+ * Print a summary of the model.
+ */
+ override def toString: String = {
+ algo match {
+ case Classification =>
+ s"TreeEnsembleModel classifier with $numTrees trees\n"
+ case Regression =>
+ s"TreeEnsembleModel regressor with $numTrees trees\n"
+ case _ => throw new IllegalArgumentException(
+ s"TreeEnsembleModel given unknown algo parameter: $algo.")
+ }
+ }
+
+ /**
+ * Print the full model to a string.
+ */
+ def toDebugString: String = {
+ val header = toString + "\n"
+ header + trees.zipWithIndex.map { case (tree, treeIndex) =>
+ s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4)
+ }.fold("")(_ + _)
+ }
+
+ /**
+ * Get number of trees in forest.
+ */
+ def numTrees: Int = trees.size
+
+ /**
+ * Get total number of nodes, summed over all trees in the forest.
+ */
+ def totalNumNodes: Int = trees.map(_.numNodes).sum
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index ca35100aa99c6..9353351af72a0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -26,7 +26,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.PartitionwiseSampledRDD
-import org.apache.spark.util.random.BernoulliSampler
+import org.apache.spark.util.random.BernoulliCellSampler
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.storage.StorageLevel
@@ -76,7 +76,7 @@ object MLUtils {
.map { line =>
val items = line.split(' ')
val label = items.head.toDouble
- val (indices, values) = items.tail.map { item =>
+ val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
val indexAndValue = item.split(':')
val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based.
val value = indexAndValue(1).toDouble
@@ -196,8 +196,8 @@ object MLUtils {
/**
* Load labeled data from a file. The data format used here is
- * , ...
- * where , are feature values in Double and is the corresponding label as Double.
+ * L, f1 f2 ...
+ * where f1, f2 are feature values in Double and L is the corresponding label as Double.
*
* @param sc SparkContext
* @param dir Directory to the input data files.
@@ -219,8 +219,8 @@ object MLUtils {
/**
* Save labeled data to a file. The data format used here is
- * , ...
- * where , are feature values in Double and is the corresponding label as Double.
+ * L, f1 f2 ...
+ * where f1, f2 are feature values in Double and L is the corresponding label as Double.
*
* @param data An RDD of LabeledPoints containing data to be saved.
* @param dir Directory to save the data.
@@ -244,7 +244,7 @@ object MLUtils {
def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = {
val numFoldsF = numFolds.toFloat
(1 to numFolds).map { fold =>
- val sampler = new BernoulliSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF,
+ val sampler = new BernoulliCellSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF,
complement = false)
val validation = new PartitionwiseSampledRDD(rdd, sampler, true, seed)
val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), true, seed)
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
new file mode 100644
index 0000000000000..42846677ed285
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.ml.classification.LogisticRegression;
+import org.apache.spark.ml.feature.StandardScaler;
+import org.apache.spark.sql.api.java.JavaSQLContext;
+import org.apache.spark.sql.api.java.JavaSchemaRDD;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite
+ .generateLogisticInputAsList;
+
+/**
+ * Test Pipeline construction and fitting in Java.
+ */
+public class JavaPipelineSuite {
+
+ private transient JavaSparkContext jsc;
+ private transient JavaSQLContext jsql;
+ private transient JavaSchemaRDD dataset;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaPipelineSuite");
+ jsql = new JavaSQLContext(jsc);
+ JavaRDD points =
+ jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
+ dataset = jsql.applySchema(points, LabeledPoint.class);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void pipeline() {
+ StandardScaler scaler = new StandardScaler()
+ .setInputCol("features")
+ .setOutputCol("scaledFeatures");
+ LogisticRegression lr = new LogisticRegression()
+ .setFeaturesCol("scaledFeatures");
+ Pipeline pipeline = new Pipeline()
+ .setStages(new PipelineStage[] {scaler, lr});
+ PipelineModel model = pipeline.fit(dataset);
+ model.transform(dataset).registerTempTable("prediction");
+ JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ predictions.collect();
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
new file mode 100644
index 0000000000000..76eb7f00329f2
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification;
+
+import java.io.Serializable;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.api.java.JavaSQLContext;
+import org.apache.spark.sql.api.java.JavaSchemaRDD;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite
+ .generateLogisticInputAsList;
+
+public class JavaLogisticRegressionSuite implements Serializable {
+
+ private transient JavaSparkContext jsc;
+ private transient JavaSQLContext jsql;
+ private transient JavaSchemaRDD dataset;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
+ jsql = new JavaSQLContext(jsc);
+ List points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
+ dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void logisticRegression() {
+ LogisticRegression lr = new LogisticRegression();
+ LogisticRegressionModel model = lr.fit(dataset);
+ model.transform(dataset).registerTempTable("prediction");
+ JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ predictions.collect();
+ }
+
+ @Test
+ public void logisticRegressionWithSetters() {
+ LogisticRegression lr = new LogisticRegression()
+ .setMaxIter(10)
+ .setRegParam(1.0);
+ LogisticRegressionModel model = lr.fit(dataset);
+ model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold
+ .registerTempTable("prediction");
+ JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ predictions.collect();
+ }
+
+ @Test
+ public void logisticRegressionFitWithVarargs() {
+ LogisticRegression lr = new LogisticRegression();
+ lr.fit(dataset, lr.maxIter().w(10), lr.regParam().w(1.0));
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
new file mode 100644
index 0000000000000..a266ebd2071a1
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning;
+
+import java.io.Serializable;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.classification.LogisticRegression;
+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
+import org.apache.spark.ml.param.ParamMap;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.api.java.JavaSQLContext;
+import org.apache.spark.sql.api.java.JavaSchemaRDD;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite
+ .generateLogisticInputAsList;
+
+public class JavaCrossValidatorSuite implements Serializable {
+
+ private transient JavaSparkContext jsc;
+ private transient JavaSQLContext jsql;
+ private transient JavaSchemaRDD dataset;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite");
+ jsql = new JavaSQLContext(jsc);
+ List points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
+ dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void crossValidationWithLogisticRegression() {
+ LogisticRegression lr = new LogisticRegression();
+ ParamMap[] lrParamMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam(), new double[] {0.001, 1000.0})
+ .addGrid(lr.maxIter(), new int[] {0, 10})
+ .build();
+ BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator();
+ CrossValidator cv = new CrossValidator()
+ .setEstimator(lr)
+ .setEstimatorParamMaps(lrParamMaps)
+ .setEvaluator(eval)
+ .setNumFolds(3);
+ CrossValidatorModel cvModel = cv.fit(dataset);
+ ParamMap bestParamMap = cvModel.bestModel().fittingParamMap();
+ Assert.assertEquals(0.001, bestParamMap.apply(lr.regParam()));
+ Assert.assertEquals(10, bestParamMap.apply(lr.maxIter()));
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
index f6ca9643227f8..af688c504cf1e 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
@@ -23,13 +23,14 @@
import scala.Tuple2;
import scala.Tuple3;
+import com.google.common.collect.Lists;
import org.jblas.DoubleMatrix;
-
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
@@ -47,61 +48,48 @@ public void tearDown() {
sc = null;
}
- static void validatePrediction(
+ void validatePrediction(
MatrixFactorizationModel model,
int users,
int products,
- int features,
DoubleMatrix trueRatings,
double matchThreshold,
boolean implicitPrefs,
DoubleMatrix truePrefs) {
- DoubleMatrix predictedU = new DoubleMatrix(users, features);
- List> userFeatures = model.userFeatures().toJavaRDD().collect();
- for (int i = 0; i < features; ++i) {
- for (Tuple2