diff --git a/.gitignore b/.gitignore index 3a68abd955b22..cd9f90d55932c 100644 --- a/.gitignore +++ b/.gitignore @@ -47,3 +47,4 @@ spark-*-bin.tar.gz unit-tests.log /lib/ rat-results.txt +scalastyle.txt diff --git a/bin/spark-shell b/bin/spark-shell index 861ab606540cd..fac006cf492ed 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -30,67 +30,189 @@ esac # Enter posix mode for bash set -o posix -CORE_PATTERN="^[0-9]+$" -MEM_PATTERN="^[0-9]+[m|g|M|G]$" - +## Global script variables FWDIR="$(cd `dirname $0`/..; pwd)" -if [ "$1" = "--help" ] || [ "$1" = "-h" ]; then - echo "Usage: spark-shell [OPTIONS]" - echo "OPTIONS:" - echo "-c --cores num, the maximum number of cores to be used by the spark shell" - echo "-em --execmem num[m|g], the memory used by each executor of spark shell" - echo "-dm --drivermem num[m|g], the memory used by the spark shell and driver" - echo "-h --help, print this help information" - exit -fi +SPARK_REPL_OPTS="${SPARK_REPL_OPTS:-""}" +DEFAULT_MASTER="local" +MASTER=${MASTER:-""} + +info_log=0 + +#CLI Color Templates +txtund=$(tput sgr 0 1) # Underline +txtbld=$(tput bold) # Bold +bldred=${txtbld}$(tput setaf 1) # red +bldyel=${txtbld}$(tput setaf 3) # yellow +bldblu=${txtbld}$(tput setaf 4) # blue +bldwht=${txtbld}$(tput setaf 7) # white +txtrst=$(tput sgr0) # Reset +info=${bldwht}*${txtrst} # Feedback +pass=${bldblu}*${txtrst} +warn=${bldred}*${txtrst} +ques=${bldblu}?${txtrst} + +# Helper function to describe the script usage +function usage() { + cat << EOF +${txtbld}Usage${txtrst}: spark-shell [OPTIONS] + +${txtbld}OPTIONS${txtrst}: + -h --help : Print this help information. + -c --cores : The maximum number of cores to be used by the Spark Shell. + -em --executor-memory : The memory used by each executor of the Spark Shell, the number + is followed by m for megabytes or g for gigabytes, e.g. "1g". + -dm --driver-memory : The memory used by the Spark Shell, the number is followed + by m for megabytes or g for gigabytes, e.g. "1g". + -m --master : A full string that describes the Spark Master, defaults to "local" + e.g. "spark://localhost:7077". + --log-conf : Enables logging of the supplied SparkConf as INFO at start of the + Spark Context. + +e.g. + spark-shell -m spark://localhost:7077 -c 4 -dm 512m -em 2g + +EOF +} + +function out_error(){ + echo -e "${txtund}${bldred}ERROR${txtrst}: $1" + usage + exit 1 +} + +function log_info(){ + [ $info_log -eq 1 ] && echo -e "${bldyel}INFO${txtrst}: $1" +} + +function log_warn(){ + echo -e "${txtund}${bldyel}WARN${txtrst}: $1" +} -for o in "$@"; do - if [ "$1" = "-c" -o "$1" = "--cores" ]; then - shift +# PATTERNS used to validate more than one optional arg. +ARG_FLAG_PATTERN="^-" +MEM_PATTERN="^[0-9]+[m|g|M|G]$" +NUM_PATTERN="^[0-9]+$" +PORT_PATTERN="^[0-9]+$" + +# Setters for optional args. +function set_cores(){ + CORE_PATTERN="^[0-9]+$" if [[ "$1" =~ $CORE_PATTERN ]]; then - SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Dspark.cores.max=$1" - shift + SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Dspark.cores.max=$1" else - echo "ERROR: wrong format for -c/--cores" - exit 1 + out_error "wrong format for $2" fi - fi - if [ "$1" = "-em" -o "$1" = "--execmem" ]; then - shift +} + +function set_em(){ if [[ $1 =~ $MEM_PATTERN ]]; then SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Dspark.executor.memory=$1" - shift else - echo "ERROR: wrong format for --execmem/-em" - exit 1 + out_error "wrong format for $2" fi - fi - if [ "$1" = "-dm" -o "$1" = "--drivermem" ]; then - shift +} + +function set_dm(){ if [[ $1 =~ $MEM_PATTERN ]]; then export SPARK_DRIVER_MEMORY=$1 - shift else - echo "ERROR: wrong format for --drivermem/-dm" - exit 1 + out_error "wrong format for $2" fi - fi -done +} + +function set_spark_log_conf(){ + SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Dspark.logConf=$1" +} -# Set MASTER from spark-env if possible -DEFAULT_SPARK_MASTER_PORT=7077 -if [ -z "$MASTER" ]; then - . $FWDIR/bin/load-spark-env.sh - if [ "x" != "x$SPARK_MASTER_IP" ]; then - if [ "y" != "y$SPARK_MASTER_PORT" ]; then - SPARK_MASTER_PORT="${SPARK_MASTER_PORT}" +function set_spark_master(){ + if ! [[ "$1" =~ $ARG_FLAG_PATTERN ]]; then + MASTER="$1" else - SPARK_MASTER_PORT=$DEFAULT_SPARK_MASTER_PORT + out_error "wrong format for $2" + fi +} + +function resolve_spark_master(){ + # Set MASTER from spark-env if possible + DEFAULT_SPARK_MASTER_PORT=7077 + if [ -z "$MASTER" ]; then + . $FWDIR/bin/load-spark-env.sh + if [ -n "$SPARK_MASTER_IP" ]; then + SPARK_MASTER_PORT="${SPARK_MASTER_PORT:-"$DEFAULT_SPARK_MASTER_PORT"}" + export MASTER="spark://${SPARK_MASTER_IP}:${SPARK_MASTER_PORT}" + fi + fi + + if [ -z "$MASTER" ]; then + MASTER="$DEFAULT_MASTER" fi - export MASTER="spark://${SPARK_MASTER_IP}:${SPARK_MASTER_PORT}" - fi -fi + +} + +function main(){ + log_info "Base Directory set to $FWDIR" + + resolve_spark_master + log_info "Spark Master is $MASTER" + + log_info "Spark REPL options $SPARK_REPL_OPTS" + if $cygwin; then + # Workaround for issue involving JLine and Cygwin + # (see http://sourceforge.net/p/jline/bugs/40/). + # If you're using the Mintty terminal emulator in Cygwin, may need to set the + # "Backspace sends ^H" setting in "Keys" section of the Mintty options + # (see https://github.com/sbt/sbt/issues/562). + stty -icanon min 1 -echo > /dev/null 2>&1 + export SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Djline.terminal=unix" + $FWDIR/bin/spark-class org.apache.spark.repl.Main "$@" + stty icanon echo > /dev/null 2>&1 + else + export SPARK_REPL_OPTS + $FWDIR/bin/spark-class org.apache.spark.repl.Main "$@" + fi +} + +for option in "$@" +do + case $option in + -h | --help ) + usage + exit 1 + ;; + -c | --cores) + shift + _1=$1 + shift + set_cores $_1 "-c/--cores" + ;; + -em | --executor-memory) + shift + _1=$1 + shift + set_em $_1 "-em/--executor-memory" + ;; + -dm | --driver-memory) + shift + _1=$1 + shift + set_dm $_1 "-dm/--driver-memory" + ;; + -m | --master) + shift + _1=$1 + shift + set_spark_master $_1 "-m/--master" + ;; + --log-conf) + shift + set_spark_log_conf "true" + info_log=1 + ;; + ?) + ;; + esac +done # Copy restore-TTY-on-exit functions from Scala script so spark-shell exits properly even in # binary distribution of Spark where Scala is not installed @@ -120,22 +242,10 @@ if [[ ! $? ]]; then saved_stty="" fi -if $cygwin; then - # Workaround for issue involving JLine and Cygwin - # (see http://sourceforge.net/p/jline/bugs/40/). - # If you're using the Mintty terminal emulator in Cygwin, may need to set the - # "Backspace sends ^H" setting in "Keys" section of the Mintty options - # (see https://github.com/sbt/sbt/issues/562). - stty -icanon min 1 -echo > /dev/null 2>&1 - export SPARK_REPL_OPTS="$SPARK_REPL_OPTS -Djline.terminal=unix" - $FWDIR/bin/spark-class org.apache.spark.repl.Main "$@" - stty icanon echo > /dev/null 2>&1 -else - export SPARK_REPL_OPTS - $FWDIR/bin/spark-class org.apache.spark.repl.Main "$@" -fi +main # record the exit status lest it be overwritten: # then reenable echo and propagate the code. exit_status=$? onExit + diff --git a/bin/spark-submit b/bin/spark-submit new file mode 100755 index 0000000000000..d92d55a032bd5 --- /dev/null +++ b/bin/spark-submit @@ -0,0 +1,38 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +export SPARK_HOME="$(cd `dirname $0`/..; pwd)" +ORIG_ARGS=$@ + +while (($#)); do + if [ $1 = "--deploy-mode" ]; then + DEPLOY_MODE=$2 + elif [ $1 = "--driver-memory" ]; then + DRIVER_MEMORY=$2 + fi + + shift +done + +if [ ! -z $DRIVER_MEMORY ] && [ ! -z $DEPLOY_MODE ] && [ $DEPLOY_MODE = "client" ]; then + export SPARK_MEM=$DRIVER_MEMORY +fi + +$SPARK_HOME/bin/spark-class org.apache.spark.deploy.SparkSubmit $ORIG_ARGS + diff --git a/core/pom.xml b/core/pom.xml index eb6cc4d3105e9..273aa69659336 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -82,10 +82,6 @@ com.google.guava guava - - com.google.code.findbugs - jsr305 - org.slf4j slf4j-api @@ -150,7 +146,7 @@ json4s-jackson_${scala.binary.version} 3.2.6 diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index a1af63fa4a391..5ceac28fe7afb 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -81,7 +81,7 @@ class SparkEnv private[spark] ( // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut // down, but let's call it anyway in case it gets fixed in a later release // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. - //actorSystem.awaitTermination() + // actorSystem.awaitTermination() } private[spark] diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 3cd71213769b7..2595c15104e87 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -167,7 +167,7 @@ extends Logging { private var initialized = false private var conf: SparkConf = null def initialize(_isDriver: Boolean, conf: SparkConf) { - TorrentBroadcast.conf = conf //TODO: we might have to fix it in tests + TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests synchronized { if (!initialized) { initialized = true diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index d9e3035e1ab59..8fd2c7e95b966 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -128,6 +128,9 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends */ object Client { def main(args: Array[String]) { + println("WARNING: This client is deprecated and will be removed in a future version of Spark.") + println("Use ./bin/spark-submit with \"--master spark://host:port\"") + val conf = new SparkConf() val driverArgs = new ClientArguments(args) diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index a73b459c3cea1..9a7a113c95715 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -66,9 +66,9 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I // TODO: In Akka 2.1.x, ActorSystem.awaitTermination hangs when you have remote actors! // This is unfortunate, but for now we just comment it out. workerActorSystems.foreach(_.shutdown()) - //workerActorSystems.foreach(_.awaitTermination()) + // workerActorSystems.foreach(_.awaitTermination()) masterActorSystems.foreach(_.shutdown()) - //masterActorSystems.foreach(_.awaitTermination()) + // masterActorSystems.foreach(_.awaitTermination()) masterActorSystems.clear() workerActorSystems.clear() } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala new file mode 100644 index 0000000000000..1fa799190409f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.io.{PrintStream, File} +import java.net.URL + +import org.apache.spark.executor.ExecutorURLClassLoader + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.Map + +/** + * Scala code behind the spark-submit script. The script handles setting up the classpath with + * relevant Spark dependencies and provides a layer over the different cluster managers and deploy + * modes that Spark supports. + */ +object SparkSubmit { + private val YARN = 1 + private val STANDALONE = 2 + private val MESOS = 4 + private val LOCAL = 8 + private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL + + private var clusterManager: Int = LOCAL + + def main(args: Array[String]) { + val appArgs = new SparkSubmitArguments(args) + if (appArgs.verbose) { + printStream.println(appArgs) + } + val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + launch(childArgs, classpath, sysProps, mainClass, appArgs.verbose) + } + + // Exposed for testing + private[spark] var printStream: PrintStream = System.err + private[spark] var exitFn: () => Unit = () => System.exit(-1) + + private[spark] def printErrorAndExit(str: String) = { + printStream.println("error: " + str) + printStream.println("run with --help for more information or --verbose for debugging output") + exitFn() + } + private[spark] def printWarning(str: String) = printStream.println("warning: " + str) + + /** + * @return + * a tuple containing the arguments for the child, a list of classpath + * entries for the child, and the main class for the child + */ + private[spark] def createLaunchEnv(appArgs: SparkSubmitArguments): (ArrayBuffer[String], + ArrayBuffer[String], Map[String, String], String) = { + if (appArgs.master.startsWith("local")) { + clusterManager = LOCAL + } else if (appArgs.master.startsWith("yarn")) { + clusterManager = YARN + } else if (appArgs.master.startsWith("spark")) { + clusterManager = STANDALONE + } else if (appArgs.master.startsWith("mesos")) { + clusterManager = MESOS + } else { + printErrorAndExit("master must start with yarn, mesos, spark, or local") + } + + // Because "yarn-standalone" and "yarn-client" encapsulate both the master + // and deploy mode, we have some logic to infer the master and deploy mode + // from each other if only one is specified, or exit early if they are at odds. + if (appArgs.deployMode == null && appArgs.master == "yarn-standalone") { + appArgs.deployMode = "cluster" + } + if (appArgs.deployMode == "cluster" && appArgs.master == "yarn-client") { + printErrorAndExit("Deploy mode \"cluster\" and master \"yarn-client\" are not compatible") + } + if (appArgs.deployMode == "client" && appArgs.master == "yarn-standalone") { + printErrorAndExit("Deploy mode \"client\" and master \"yarn-standalone\" are not compatible") + } + if (appArgs.deployMode == "cluster" && appArgs.master.startsWith("yarn")) { + appArgs.master = "yarn-standalone" + } + if (appArgs.deployMode != "cluster" && appArgs.master.startsWith("yarn")) { + appArgs.master = "yarn-client" + } + + val deployOnCluster = Option(appArgs.deployMode).getOrElse("client") == "cluster" + + val childClasspath = new ArrayBuffer[String]() + val childArgs = new ArrayBuffer[String]() + val sysProps = new HashMap[String, String]() + var childMainClass = "" + + if (clusterManager == MESOS && deployOnCluster) { + printErrorAndExit("Mesos does not support running the driver on the cluster") + } + + if (!deployOnCluster) { + childMainClass = appArgs.mainClass + childClasspath += appArgs.primaryResource + } else if (clusterManager == YARN) { + childMainClass = "org.apache.spark.deploy.yarn.Client" + childArgs += ("--jar", appArgs.primaryResource) + childArgs += ("--class", appArgs.mainClass) + } + + val options = List[OptionAssigner]( + new OptionAssigner(appArgs.master, ALL_CLUSTER_MGRS, false, sysProp = "spark.master"), + new OptionAssigner(appArgs.driverMemory, YARN, true, clOption = "--driver-memory"), + new OptionAssigner(appArgs.name, YARN, true, clOption = "--name"), + new OptionAssigner(appArgs.queue, YARN, true, clOption = "--queue"), + new OptionAssigner(appArgs.queue, YARN, false, sysProp = "spark.yarn.queue"), + new OptionAssigner(appArgs.numExecutors, YARN, true, clOption = "--num-executors"), + new OptionAssigner(appArgs.numExecutors, YARN, false, sysProp = "spark.executor.instances"), + new OptionAssigner(appArgs.executorMemory, YARN, true, clOption = "--executor-memory"), + new OptionAssigner(appArgs.executorMemory, STANDALONE | MESOS | YARN, false, + sysProp = "spark.executor.memory"), + new OptionAssigner(appArgs.driverMemory, STANDALONE, true, clOption = "--memory"), + new OptionAssigner(appArgs.driverCores, STANDALONE, true, clOption = "--cores"), + new OptionAssigner(appArgs.executorCores, YARN, true, clOption = "--executor-cores"), + new OptionAssigner(appArgs.executorCores, YARN, false, sysProp = "spark.executor.cores"), + new OptionAssigner(appArgs.totalExecutorCores, STANDALONE | MESOS, false, + sysProp = "spark.cores.max"), + new OptionAssigner(appArgs.files, YARN, false, sysProp = "spark.yarn.dist.files"), + new OptionAssigner(appArgs.files, YARN, true, clOption = "--files"), + new OptionAssigner(appArgs.archives, YARN, false, sysProp = "spark.yarn.dist.archives"), + new OptionAssigner(appArgs.archives, YARN, true, clOption = "--archives"), + new OptionAssigner(appArgs.jars, YARN, true, clOption = "--addJars") + ) + + // more jars + if (appArgs.jars != null && !deployOnCluster) { + for (jar <- appArgs.jars.split(",")) { + childClasspath += jar + } + } + + for (opt <- options) { + if (opt.value != null && deployOnCluster == opt.deployOnCluster && + (clusterManager & opt.clusterManager) != 0) { + if (opt.clOption != null) { + childArgs += (opt.clOption, opt.value) + } else if (opt.sysProp != null) { + sysProps.put(opt.sysProp, opt.value) + } + } + } + + if (deployOnCluster && clusterManager == STANDALONE) { + if (appArgs.supervise) { + childArgs += "--supervise" + } + + childMainClass = "org.apache.spark.deploy.Client" + childArgs += "launch" + childArgs += (appArgs.master, appArgs.primaryResource, appArgs.mainClass) + } + + // args + if (appArgs.childArgs != null) { + if (!deployOnCluster || clusterManager == STANDALONE) { + childArgs ++= appArgs.childArgs + } else if (clusterManager == YARN) { + for (arg <- appArgs.childArgs) { + childArgs += ("--args", arg) + } + } + } + + (childArgs, childClasspath, sysProps, childMainClass) + } + + private def launch(childArgs: ArrayBuffer[String], childClasspath: ArrayBuffer[String], + sysProps: Map[String, String], childMainClass: String, verbose: Boolean = false) { + + if (verbose) { + System.err.println(s"Main class:\n$childMainClass") + System.err.println(s"Arguments:\n${childArgs.mkString("\n")}") + System.err.println(s"System properties:\n${sysProps.mkString("\n")}") + System.err.println(s"Classpath elements:\n${childClasspath.mkString("\n")}") + System.err.println("\n") + } + + val loader = new ExecutorURLClassLoader(new Array[URL](0), + Thread.currentThread.getContextClassLoader) + Thread.currentThread.setContextClassLoader(loader) + + for (jar <- childClasspath) { + addJarToClasspath(jar, loader) + } + + for ((key, value) <- sysProps) { + System.setProperty(key, value) + } + + val mainClass = Class.forName(childMainClass, true, loader) + val mainMethod = mainClass.getMethod("main", new Array[String](0).getClass) + mainMethod.invoke(null, childArgs.toArray) + } + + private def addJarToClasspath(localJar: String, loader: ExecutorURLClassLoader) { + val localJarFile = new File(localJar) + if (!localJarFile.exists()) { + printWarning(s"Jar $localJar does not exist, skipping.") + } + + val url = localJarFile.getAbsoluteFile.toURI.toURL + loader.addURL(url) + } +} + +private[spark] class OptionAssigner(val value: String, + val clusterManager: Int, + val deployOnCluster: Boolean, + val clOption: String = null, + val sysProp: String = null +) { } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala new file mode 100644 index 0000000000000..9c8f54ea6f77a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import scala.collection.mutable.ArrayBuffer + +/** + * Parses and encapsulates arguments from the spark-submit script. + */ +private[spark] class SparkSubmitArguments(args: Array[String]) { + var master: String = "local" + var deployMode: String = null + var executorMemory: String = null + var executorCores: String = null + var totalExecutorCores: String = null + var driverMemory: String = null + var driverCores: String = null + var supervise: Boolean = false + var queue: String = null + var numExecutors: String = null + var files: String = null + var archives: String = null + var mainClass: String = null + var primaryResource: String = null + var name: String = null + var childArgs: ArrayBuffer[String] = new ArrayBuffer[String]() + var jars: String = null + var verbose: Boolean = false + + loadEnvVars() + parseOpts(args.toList) + + // Sanity checks + if (args.length == 0) printUsageAndExit(-1) + if (primaryResource == null) SparkSubmit.printErrorAndExit("Must specify a primary resource") + if (mainClass == null) SparkSubmit.printErrorAndExit("Must specify a main class with --class") + + override def toString = { + s"""Parsed arguments: + | master $master + | deployMode $deployMode + | executorMemory $executorMemory + | executorCores $executorCores + | totalExecutorCores $totalExecutorCores + | driverMemory $driverMemory + | drivercores $driverCores + | supervise $supervise + | queue $queue + | numExecutors $numExecutors + | files $files + | archives $archives + | mainClass $mainClass + | primaryResource $primaryResource + | name $name + | childArgs [${childArgs.mkString(" ")}] + | jars $jars + | verbose $verbose + """.stripMargin + } + + private def loadEnvVars() { + Option(System.getenv("MASTER")).map(master = _) + Option(System.getenv("DEPLOY_MODE")).map(deployMode = _) + } + + private def parseOpts(opts: List[String]): Unit = opts match { + case ("--name") :: value :: tail => + name = value + parseOpts(tail) + + case ("--master") :: value :: tail => + master = value + parseOpts(tail) + + case ("--class") :: value :: tail => + mainClass = value + parseOpts(tail) + + case ("--deploy-mode") :: value :: tail => + if (value != "client" && value != "cluster") { + SparkSubmit.printErrorAndExit("--deploy-mode must be either \"client\" or \"cluster\"") + } + deployMode = value + parseOpts(tail) + + case ("--num-executors") :: value :: tail => + numExecutors = value + parseOpts(tail) + + case ("--total-executor-cores") :: value :: tail => + totalExecutorCores = value + parseOpts(tail) + + case ("--executor-cores") :: value :: tail => + executorCores = value + parseOpts(tail) + + case ("--executor-memory") :: value :: tail => + executorMemory = value + parseOpts(tail) + + case ("--driver-memory") :: value :: tail => + driverMemory = value + parseOpts(tail) + + case ("--driver-cores") :: value :: tail => + driverCores = value + parseOpts(tail) + + case ("--supervise") :: tail => + supervise = true + parseOpts(tail) + + case ("--queue") :: value :: tail => + queue = value + parseOpts(tail) + + case ("--files") :: value :: tail => + files = value + parseOpts(tail) + + case ("--archives") :: value :: tail => + archives = value + parseOpts(tail) + + case ("--arg") :: value :: tail => + childArgs += value + parseOpts(tail) + + case ("--jars") :: value :: tail => + jars = value + parseOpts(tail) + + case ("--help" | "-h") :: tail => + printUsageAndExit(0) + + case ("--verbose" | "-v") :: tail => + verbose = true + parseOpts(tail) + + case value :: tail => + if (primaryResource != null) { + val error = s"Found two conflicting resources, $value and $primaryResource." + + " Expecting only one resource." + SparkSubmit.printErrorAndExit(error) + } + primaryResource = value + parseOpts(tail) + + case Nil => + } + + private def printUsageAndExit(exitCode: Int, unknownParam: Any = null) { + val outStream = SparkSubmit.printStream + if (unknownParam != null) { + outStream.println("Unknown/unsupported param " + unknownParam) + } + outStream.println( + """Usage: spark-submit [options] + |Options: + | --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local. + | --deploy-mode DEPLOY_MODE Mode to deploy the app in, either 'client' or 'cluster'. + | --class CLASS_NAME Name of your app's main class (required for Java apps). + | --arg ARG Argument to be passed to your application's main class. This + | option can be specified multiple times for multiple args. + | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: 512M). + | --name NAME The name of your application (Default: 'Spark'). + | --jars JARS A comma-separated list of local jars to include on the + | driver classpath and that SparkContext.addJar will work + | with. Doesn't work on standalone with 'cluster' deploy mode. + | + | Spark standalone with cluster deploy mode only: + | --driver-cores NUM Cores for driver (Default: 1). + | --supervise If given, restarts the driver on failure. + | + | Spark standalone and Mesos only: + | --total-executor-cores NUM Total cores for all executors. + | + | YARN-only: + | --executor-cores NUM Number of cores per executor (Default: 1). + | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G). + | --queue QUEUE_NAME The YARN queue to submit to (Default: 'default'). + | --num-executors NUM Number of executors to (Default: 2). + | --files FILES Comma separated list of files to be placed in the working dir + | of each executor. + | --archives ARCHIVES Comma separated list of archives to be extracted into the + | working dir of each executor.""".stripMargin + ) + SparkSubmit.exitFn() + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala index a730fe1f599af..4433a2ec29be6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala @@ -30,7 +30,7 @@ import org.apache.spark.deploy.master.MasterMessages.ElectedLeader * [[org.apache.spark.deploy.master.MasterMessages.RevokedLeadership RevokedLeadership]] */ private[spark] trait LeaderElectionAgent extends Actor { - //TODO: LeaderElectionAgent does not necessary to be an Actor anymore, need refactoring. + // TODO: LeaderElectionAgent does not necessary to be an Actor anymore, need refactoring. val masterActor: ActorRef } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 8fe9b848ba145..aecb069e4202b 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -112,11 +112,10 @@ private[spark] class Executor( } } - // Create our ClassLoader and set it on this thread + // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager private val urlClassLoader = createClassLoader() private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) - Thread.currentThread.setContextClassLoader(replClassLoader) // Akka's message frame size. If task result is bigger than this, we use the block manager // to send the result back. @@ -276,7 +275,6 @@ private[spark] class Executor( // have left some weird state around depending on when the exception was thrown, but on // the other hand, maybe we could detect that when future tasks fail and exit then. logError("Exception in task ID " + taskId, t) - //System.exit(1) } } finally { // TODO: Unregister shuffle memory only for ResultTask @@ -294,7 +292,7 @@ private[spark] class Executor( * created by the interpreter to the search path */ private def createClassLoader(): ExecutorURLClassLoader = { - val loader = this.getClass.getClassLoader + val loader = Thread.currentThread().getContextClassLoader // For each of the jars in the jarSet, add them to the class loader. // We assume each of the files has already been fetched. diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala index 6883a54494598..3e3e18c3537d0 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala @@ -42,7 +42,7 @@ private[spark] class MetricsConfig(val configFile: Option[String]) extends Loggi } def initialize() { - //Add default properties in case there's no properties file + // Add default properties in case there's no properties file setDefaultProperties(properties) // If spark.metrics.conf is not set, try to get file in class path diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala index 8fd9c2b87d256..2f7576c53b482 100644 --- a/core/src/main/scala/org/apache/spark/network/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/Connection.scala @@ -48,7 +48,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, channel.socket.setTcpNoDelay(true) channel.socket.setReuseAddress(true) channel.socket.setKeepAlive(true) - /*channel.socket.setReceiveBufferSize(32768) */ + /* channel.socket.setReceiveBufferSize(32768) */ @volatile private var closed = false var onCloseCallback: Connection => Unit = null @@ -206,12 +206,12 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, private class Outbox { val messages = new Queue[Message]() - val defaultChunkSize = 65536 //32768 //16384 + val defaultChunkSize = 65536 var nextMessageToBeUsed = 0 def addMessage(message: Message) { messages.synchronized{ - /*messages += message*/ + /* messages += message */ messages.enqueue(message) logDebug("Added [" + message + "] to outbox for sending to " + "[" + getRemoteConnectionManagerId() + "]") @@ -221,8 +221,8 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, def getChunk(): Option[MessageChunk] = { messages.synchronized { while (!messages.isEmpty) { - /*nextMessageToBeUsed = nextMessageToBeUsed % messages.size */ - /*val message = messages(nextMessageToBeUsed)*/ + /* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */ + /* val message = messages(nextMessageToBeUsed) */ val message = messages.dequeue val chunk = message.getChunkForSending(defaultChunkSize) if (chunk.isDefined) { @@ -262,7 +262,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, val currentBuffers = new ArrayBuffer[ByteBuffer]() - /*channel.socket.setSendBufferSize(256 * 1024)*/ + /* channel.socket.setSendBufferSize(256 * 1024) */ override def getRemoteAddress() = address @@ -355,7 +355,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } case None => { // changeConnectionKeyInterest(0) - /*key.interestOps(0)*/ + /* key.interestOps(0) */ return false } } @@ -540,10 +540,10 @@ private[spark] class ReceivingConnection( return false } - /*logDebug("Read " + bytesRead + " bytes for the buffer")*/ + /* logDebug("Read " + bytesRead + " bytes for the buffer") */ if (currentChunk.buffer.remaining == 0) { - /*println("Filled buffer at " + System.currentTimeMillis)*/ + /* println("Filled buffer at " + System.currentTimeMillis) */ val bufferMessage = inbox.getMessageForChunk(currentChunk).get if (bufferMessage.isCompletelyReceived) { bufferMessage.flip diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index a75130cba2a2e..6b0a972f0bbe0 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -505,7 +505,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, } } handleMessageExecutor.execute(runnable) - /*handleMessage(connection, message)*/ + /* handleMessage(connection, message) */ } private def handleClientAuthentication( @@ -733,7 +733,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, logTrace("Sending Security [" + message + "] to [" + connManagerId + "]") val connection = connectionsById.getOrElseUpdate(connManagerId, startNewConnection()) - //send security message until going connection has been authenticated + // send security message until going connection has been authenticated connection.send(message) wakeupSelector() @@ -859,14 +859,14 @@ private[spark] object ConnectionManager { None }) - /*testSequentialSending(manager)*/ - /*System.gc()*/ + /* testSequentialSending(manager) */ + /* System.gc() */ - /*testParallelSending(manager)*/ - /*System.gc()*/ + /* testParallelSending(manager) */ + /* System.gc() */ - /*testParallelDecreasingSending(manager)*/ - /*System.gc()*/ + /* testParallelDecreasingSending(manager) */ + /* System.gc() */ testContinuousSending(manager) System.gc() @@ -948,7 +948,7 @@ private[spark] object ConnectionManager { val ms = finishTime - startTime val tput = mb * 1000.0 / ms println("--------------------------") - /*println("Started at " + startTime + ", finished at " + finishTime) */ + /* println("Started at " + startTime + ", finished at " + finishTime) */ println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)") println("--------------------------") println() diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala index 35f64134b073a..9d9b9dbdd5331 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala @@ -47,8 +47,8 @@ private[spark] object ConnectionManagerTest extends Logging{ val slaves = slavesFile.mkString.split("\n") slavesFile.close() - /*println("Slaves")*/ - /*slaves.foreach(println)*/ + /* println("Slaves") */ + /* slaves.foreach(println) */ val tasknum = if (args.length > 2) args(2).toInt else slaves.length val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024 val count = if (args.length > 4) args(4).toInt else 3 diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala index 3c09a713c6fe0..2b41c403b2e0a 100644 --- a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala +++ b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala @@ -27,7 +27,7 @@ private[spark] object ReceiverTest { println("Started connection manager with id = " + manager.id) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - /*println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis)*/ + /* println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis) */ val buffer = ByteBuffer.wrap("response".getBytes) Some(Message.createBufferMessage(buffer, msg.id)) }) diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala index aac2c24a46faa..14c094c6177d5 100644 --- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala +++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala @@ -50,7 +50,7 @@ private[spark] object SenderTest { (0 until count).foreach(i => { val dataMessage = Message.createBufferMessage(buffer.duplicate) val startTime = System.currentTimeMillis - /*println("Started timer at " + startTime)*/ + /* println("Started timer at " + startTime) */ val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage) .map { response => val buffer = response.asInstanceOf[BufferMessage].buffers(0) diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala index f9082ffb9141a..4164e81d3a8ae 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala @@ -32,7 +32,7 @@ private[spark] class FileHeader ( buf.writeInt(fileLen) buf.writeInt(blockId.name.length) blockId.name.foreach((x: Char) => buf.writeByte(x)) - //padding the rest of header + // padding the rest of header if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) { buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes) } else { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 77c558ac46f6f..ef3d24d746829 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -84,7 +84,7 @@ class DAGScheduler( private[scheduler] val stageIdToJobIds = new TimeStampedHashMap[Int, HashSet[Int]] private[scheduler] val stageIdToStage = new TimeStampedHashMap[Int, Stage] private[scheduler] val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] - private[scheduler] val stageIdToActiveJob = new HashMap[Int, ActiveJob] + private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob] private[scheduler] val resultStageToJob = new HashMap[Stage, ActiveJob] private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo] @@ -536,7 +536,7 @@ class DAGScheduler( listenerBus.post(SparkListenerJobStart(job.jobId, Array[Int](), properties)) runLocally(job) } else { - stageIdToActiveJob(jobId) = job + jobIdToActiveJob(jobId) = job activeJobs += job resultStageToJob(finalStage) = job listenerBus.post( @@ -559,7 +559,7 @@ class DAGScheduler( // Cancel all running jobs. runningStages.map(_.jobId).foreach(handleJobCancellation) activeJobs.clear() // These should already be empty by this point, - stageIdToActiveJob.clear() // but just in case we lost track of some jobs... + jobIdToActiveJob.clear() // but just in case we lost track of some jobs... case ExecutorAdded(execId, host) => handleExecutorAdded(execId, host) @@ -569,7 +569,6 @@ class DAGScheduler( case BeginEvent(task, taskInfo) => for ( - job <- stageIdToActiveJob.get(task.stageId); stage <- stageIdToStage.get(task.stageId); stageInfo <- stageToInfos.get(stage) ) { @@ -697,7 +696,7 @@ class DAGScheduler( private def activeJobForStage(stage: Stage): Option[Int] = { if (stageIdToJobIds.contains(stage.id)) { val jobsThatUseStage: Array[Int] = stageIdToJobIds(stage.id).toArray.sorted - jobsThatUseStage.find(stageIdToActiveJob.contains) + jobsThatUseStage.find(jobIdToActiveJob.contains) } else { None } @@ -750,10 +749,10 @@ class DAGScheduler( } } - val properties = if (stageIdToActiveJob.contains(jobId)) { - stageIdToActiveJob(stage.jobId).properties + val properties = if (jobIdToActiveJob.contains(jobId)) { + jobIdToActiveJob(stage.jobId).properties } else { - //this stage will be assigned to "default" pool + // this stage will be assigned to "default" pool null } @@ -827,7 +826,7 @@ class DAGScheduler( job.numFinished += 1 // If the whole job has finished, remove it if (job.numFinished == job.numPartitions) { - stageIdToActiveJob -= stage.jobId + jobIdToActiveJob -= stage.jobId activeJobs -= job resultStageToJob -= stage markStageAsFinished(stage) @@ -986,11 +985,11 @@ class DAGScheduler( val independentStages = removeJobAndIndependentStages(jobId) independentStages.foreach(taskScheduler.cancelTasks) val error = new SparkException("Job %d cancelled".format(jobId)) - val job = stageIdToActiveJob(jobId) + val job = jobIdToActiveJob(jobId) job.listener.jobFailed(error) jobIdToStageIds -= jobId activeJobs -= job - stageIdToActiveJob -= jobId + jobIdToActiveJob -= jobId listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error, job.finalStage.id))) } } @@ -1011,7 +1010,7 @@ class DAGScheduler( val error = new SparkException("Job aborted: " + reason) job.listener.jobFailed(error) jobIdToStageIdsRemove(job.jobId) - stageIdToActiveJob -= resultStage.jobId + jobIdToActiveJob -= resultStage.jobId activeJobs -= job resultStageToJob -= resultStage listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error, failedStage.id))) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 990e01a3e7959..7bfc30b4208a3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -172,7 +172,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A properties += ((key, value)) } } - //TODO (prashant) send conf instead of properties + // TODO (prashant) send conf instead of properties driverActor = actorSystem.actorOf( Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME) } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 6b6d814c1fe92..926e71573be32 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -107,7 +107,8 @@ class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends Deser kryo.readClassAndObject(input).asInstanceOf[T] } catch { // DeserializationStream uses the EOF exception to indicate stopping condition. - case _: KryoException => throw new EOFException + case e: KryoException if e.getMessage.toLowerCase.contains("buffer underflow") => + throw new EOFException } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index bcfc39146a61e..2fbbda5b76c74 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -284,7 +284,7 @@ object BlockFetcherIterator { } } catch { case x: InterruptedException => logInfo("Copier Interrupted") - //case _ => throw new SparkException("Exception Throw in Shuffle Copier") + // case _ => throw new SparkException("Exception Throw in Shuffle Copier") } } } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 6e1736f6fbc23..e1a1f209c9282 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -18,13 +18,14 @@ package org.apache.spark.ui import java.net.{InetSocketAddress, URL} +import javax.servlet.DispatcherType import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import scala.annotation.tailrec import scala.util.{Failure, Success, Try} import scala.xml.Node -import org.eclipse.jetty.server.{DispatcherType, Server} +import org.eclipse.jetty.server.Server import org.eclipse.jetty.server.handler._ import org.eclipse.jetty.servlet._ import org.eclipse.jetty.util.thread.QueuedThreadPool diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala index f3c93d4214ad0..70d62b66a4829 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala @@ -25,7 +25,7 @@ import org.apache.spark.scheduler.Schedulable import org.apache.spark.ui.Page._ import org.apache.spark.ui.UIUtils -/** Page showing list of all ongoing and recently finished stages and pools*/ +/** Page showing list of all ongoing and recently finished stages and pools */ private[ui] class IndexPage(parent: JobProgressUI) { private val appName = parent.appName private val basePath = parent.basePath diff --git a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala index 4d8b01dbe6e1b..a7b24ff695214 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala @@ -84,7 +84,7 @@ private[ui] class BlockManagerListener(storageStatusListener: StorageStatusListe override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = synchronized { val rddInfo = stageSubmitted.stageInfo.rddInfo - _rddInfoMap(rddInfo.id) = rddInfo + _rddInfoMap.getOrElseUpdate(rddInfo.id, rddInfo) } override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index a8d20ee332355..cdbbc65292188 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -112,7 +112,7 @@ private[spark] object ClosureCleaner extends Logging { accessedFields(cls) = Set[String]() for (cls <- func.getClass :: innerClasses) getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0) - //logInfo("accessedFields: " + accessedFields) + // logInfo("accessedFields: " + accessedFields) val inInterpreter = { try { @@ -139,13 +139,13 @@ private[spark] object ClosureCleaner extends Logging { val field = cls.getDeclaredField(fieldName) field.setAccessible(true) val value = field.get(obj) - //logInfo("1: Setting " + fieldName + " on " + cls + " to " + value); + // logInfo("1: Setting " + fieldName + " on " + cls + " to " + value); field.set(outer, value) } } if (outer != null) { - //logInfo("2: Setting $outer on " + func.getClass + " to " + outer); + // logInfo("2: Setting $outer on " + func.getClass + " to " + outer); val field = func.getClass.getDeclaredField("$outer") field.setAccessible(true) field.set(func, outer) @@ -153,7 +153,7 @@ private[spark] object ClosureCleaner extends Logging { } private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = { - //logInfo("Creating a " + cls + " with outer = " + outer) + // logInfo("Creating a " + cls + " with outer = " + outer) if (!inInterpreter) { // This is a bona fide closure class, whose constructor has no effects // other than to set its fields, so use its constructor @@ -170,7 +170,7 @@ private[spark] object ClosureCleaner extends Logging { val newCtor = rf.newConstructorForSerialization(cls, parentCtor) val obj = newCtor.newInstance().asInstanceOf[AnyRef] if (outer != null) { - //logInfo("3: Setting $outer on " + cls + " to " + outer); + // logInfo("3: Setting $outer on " + cls + " to " + outer); val field = cls.getDeclaredField("$outer") field.setAccessible(true) field.set(obj, outer) diff --git a/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala b/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala index c539d2f708f95..4188a869c13da 100644 --- a/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala +++ b/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala @@ -49,7 +49,7 @@ private[akka] class IndestructibleActorSystemImpl( if (isFatalError(cause) && !settings.JvmExitOnFatalError) { log.error(cause, "Uncaught fatal error from thread [{}] not shutting down " + "ActorSystem [{}] tolerating and continuing.... ", thread.getName, name) - //shutdown() //TODO make it configurable + // shutdown() //TODO make it configurable } else { fallbackHandler.uncaughtException(thread, cause) } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 346f2b7856791..d9a6af61872d1 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -195,7 +195,7 @@ private[spark] object JsonProtocol { taskMetrics.shuffleWriteMetrics.map(shuffleWriteMetricsToJson).getOrElse(JNothing) val updatedBlocks = taskMetrics.updatedBlocks.map { blocks => JArray(blocks.toList.map { case (id, status) => - ("Block ID" -> blockIdToJson(id)) ~ + ("Block ID" -> id.toString) ~ ("Status" -> blockStatusToJson(status)) }) }.getOrElse(JNothing) @@ -284,35 +284,6 @@ private[spark] object JsonProtocol { ("Replication" -> storageLevel.replication) } - def blockIdToJson(blockId: BlockId): JValue = { - val blockType = Utils.getFormattedClassName(blockId) - val json: JObject = blockId match { - case rddBlockId: RDDBlockId => - ("RDD ID" -> rddBlockId.rddId) ~ - ("Split Index" -> rddBlockId.splitIndex) - case shuffleBlockId: ShuffleBlockId => - ("Shuffle ID" -> shuffleBlockId.shuffleId) ~ - ("Map ID" -> shuffleBlockId.mapId) ~ - ("Reduce ID" -> shuffleBlockId.reduceId) - case broadcastBlockId: BroadcastBlockId => - "Broadcast ID" -> broadcastBlockId.broadcastId - case broadcastHelperBlockId: BroadcastHelperBlockId => - ("Broadcast Block ID" -> blockIdToJson(broadcastHelperBlockId.broadcastId)) ~ - ("Helper Type" -> broadcastHelperBlockId.hType) - case taskResultBlockId: TaskResultBlockId => - "Task ID" -> taskResultBlockId.taskId - case streamBlockId: StreamBlockId => - ("Stream ID" -> streamBlockId.streamId) ~ - ("Unique ID" -> streamBlockId.uniqueId) - case tempBlockId: TempBlockId => - val uuid = UUIDToJson(tempBlockId.id) - "Temp ID" -> uuid - case testBlockId: TestBlockId => - "Test ID" -> testBlockId.id - } - ("Type" -> blockType) ~ json - } - def blockStatusToJson(blockStatus: BlockStatus): JValue = { val storageLevel = storageLevelToJson(blockStatus.storageLevel) ("Storage Level" -> storageLevel) ~ @@ -513,7 +484,7 @@ private[spark] object JsonProtocol { Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson) metrics.updatedBlocks = Utils.jsonOption(json \ "Updated Blocks").map { value => value.extract[List[JValue]].map { block => - val id = blockIdFromJson(block \ "Block ID") + val id = BlockId((block \ "Block ID").extract[String]) val status = blockStatusFromJson(block \ "Status") (id, status) } @@ -616,50 +587,6 @@ private[spark] object JsonProtocol { StorageLevel(useDisk, useMemory, deserialized, replication) } - def blockIdFromJson(json: JValue): BlockId = { - val rddBlockId = Utils.getFormattedClassName(RDDBlockId) - val shuffleBlockId = Utils.getFormattedClassName(ShuffleBlockId) - val broadcastBlockId = Utils.getFormattedClassName(BroadcastBlockId) - val broadcastHelperBlockId = Utils.getFormattedClassName(BroadcastHelperBlockId) - val taskResultBlockId = Utils.getFormattedClassName(TaskResultBlockId) - val streamBlockId = Utils.getFormattedClassName(StreamBlockId) - val tempBlockId = Utils.getFormattedClassName(TempBlockId) - val testBlockId = Utils.getFormattedClassName(TestBlockId) - - (json \ "Type").extract[String] match { - case `rddBlockId` => - val rddId = (json \ "RDD ID").extract[Int] - val splitIndex = (json \ "Split Index").extract[Int] - new RDDBlockId(rddId, splitIndex) - case `shuffleBlockId` => - val shuffleId = (json \ "Shuffle ID").extract[Int] - val mapId = (json \ "Map ID").extract[Int] - val reduceId = (json \ "Reduce ID").extract[Int] - new ShuffleBlockId(shuffleId, mapId, reduceId) - case `broadcastBlockId` => - val broadcastId = (json \ "Broadcast ID").extract[Long] - new BroadcastBlockId(broadcastId) - case `broadcastHelperBlockId` => - val broadcastBlockId = - blockIdFromJson(json \ "Broadcast Block ID").asInstanceOf[BroadcastBlockId] - val hType = (json \ "Helper Type").extract[String] - new BroadcastHelperBlockId(broadcastBlockId, hType) - case `taskResultBlockId` => - val taskId = (json \ "Task ID").extract[Long] - new TaskResultBlockId(taskId) - case `streamBlockId` => - val streamId = (json \ "Stream ID").extract[Int] - val uniqueId = (json \ "Unique ID").extract[Long] - new StreamBlockId(streamId, uniqueId) - case `tempBlockId` => - val tempId = UUIDFromJson(json \ "Temp ID") - new TempBlockId(tempId) - case `testBlockId` => - val testId = (json \ "Test ID").extract[String] - new TestBlockId(testId) - } - } - def blockStatusFromJson(json: JValue): BlockStatus = { val storageLevel = storageLevelFromJson(json \ "Storage Level") val memorySize = (json \ "Memory Size").extract[Long] diff --git a/core/src/main/scala/org/apache/spark/util/MutablePair.scala b/core/src/main/scala/org/apache/spark/util/MutablePair.scala index 2c1a6f8fd0a44..a6b39247a54ca 100644 --- a/core/src/main/scala/org/apache/spark/util/MutablePair.scala +++ b/core/src/main/scala/org/apache/spark/util/MutablePair.scala @@ -24,8 +24,8 @@ package org.apache.spark.util * @param _1 Element 1 of this MutablePair * @param _2 Element 2 of this MutablePair */ -case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T1, - @specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T2] +case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/* , AnyRef */) T1, + @specialized(Int, Long, Double, Char, Boolean/* , AnyRef */) T2] (var _1: T1, var _2: T2) extends Product2[T1, T2] { diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 6c73ea6949dd2..4e7c34e6d1ada 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -66,7 +66,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkConte test ("add value to collection accumulators") { val maxI = 1000 - for (nThreads <- List(1, 10)) { //test single & multi-threaded + for (nThreads <- List(1, 10)) { // test single & multi-threaded sc = new SparkContext("local[" + nThreads + "]", "test") val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) val d = sc.parallelize(1 to maxI) @@ -83,7 +83,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkConte test ("value not readable in tasks") { val maxI = 1000 - for (nThreads <- List(1, 10)) { //test single & multi-threaded + for (nThreads <- List(1, 10)) { // test single & multi-threaded sc = new SparkContext("local[" + nThreads + "]", "test") val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) val d = sc.parallelize(1 to maxI) @@ -124,7 +124,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkConte test ("localValue readable in tasks") { val maxI = 1000 - for (nThreads <- List(1, 10)) { //test single & multi-threaded + for (nThreads <- List(1, 10)) { // test single & multi-threaded sc = new SparkContext("local[" + nThreads + "]", "test") val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]()) val groupedInts = (1 to (maxI/20)).map {x => (20 * (x - 1) to 20 * x).toSet} diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index d2e29f20f0b08..d2555b7c052c1 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -432,7 +432,6 @@ object CheckpointSuite { // This is a custom cogroup function that does not use mapValues like // the PairRDDFunctions.cogroup() def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = { - //println("First = " + first + ", second = " + second) new CoGroupedRDD[K]( Seq(first.asInstanceOf[RDD[(K, _)]], second.asInstanceOf[RDD[(K, _)]]), part diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index 996db70809320..7c30626a0c421 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -146,7 +146,7 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet assert(intercept[SparkException]{ arrs.distinct() }.getMessage.contains("array")) // We can't catch all usages of arrays, since they might occur inside other collections: - //assert(fails { arrPairs.distinct() }) + // assert(fails { arrPairs.distinct() }) assert(intercept[SparkException]{ arrPairs.partitionBy(new HashPartitioner(2)) }.getMessage.contains("array")) assert(intercept[SparkException]{ arrPairs.join(arrPairs) }.getMessage.contains("array")) assert(intercept[SparkException]{ arrPairs.leftOuterJoin(arrPairs) }.getMessage.contains("array")) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala new file mode 100644 index 0000000000000..4e489cd9b66a6 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.io.{OutputStream, PrintStream} + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers + +import org.apache.spark.deploy.SparkSubmit._ + + +class SparkSubmitSuite extends FunSuite with ShouldMatchers { + + val noOpOutputStream = new OutputStream { + def write(b: Int) = {} + } + + /** Simple PrintStream that reads data into a buffer */ + class BufferPrintStream extends PrintStream(noOpOutputStream) { + var lineBuffer = ArrayBuffer[String]() + override def println(line: String) { + lineBuffer += line + } + } + + /** Returns true if the script exits and the given search string is printed. */ + def testPrematureExit(input: Array[String], searchString: String): Boolean = { + val printStream = new BufferPrintStream() + SparkSubmit.printStream = printStream + + @volatile var exitedCleanly = false + SparkSubmit.exitFn = () => exitedCleanly = true + + val thread = new Thread { + override def run() = try { + SparkSubmit.main(input) + } catch { + // If exceptions occur after the "exit" has happened, fine to ignore them. + // These represent code paths not reachable during normal execution. + case e: Exception => if (!exitedCleanly) throw e + } + } + thread.start() + thread.join() + printStream.lineBuffer.find(s => s.contains(searchString)).size > 0 + } + + test("prints usage on empty input") { + testPrematureExit(Array[String](), "Usage: spark-submit") should be (true) + } + + test("prints usage with only --help") { + testPrematureExit(Array("--help"), "Usage: spark-submit") should be (true) + } + + test("handles multiple binary definitions") { + val adjacentJars = Array("foo.jar", "bar.jar") + testPrematureExit(adjacentJars, "error: Found two conflicting resources") should be (true) + + val nonAdjacentJars = + Array("foo.jar", "--master", "123", "--class", "abc", "bar.jar") + testPrematureExit(nonAdjacentJars, "error: Found two conflicting resources") should be (true) + } + + test("handle binary specified but not class") { + testPrematureExit(Array("foo.jar"), "must specify a main class") + } + + test("handles YARN cluster mode") { + val clArgs = Array("thejar.jar", "--deploy-mode", "cluster", + "--master", "yarn", "--executor-memory", "5g", "--executor-cores", "5", + "--class", "org.SomeClass", "--jars", "one.jar,two.jar,three.jar", + "--arg", "arg1", "--arg", "arg2", "--driver-memory", "4g", + "--queue", "thequeue", "--files", "file1.txt,file2.txt", + "--archives", "archive1.txt,archive2.txt", "--num-executors", "6") + val appArgs = new SparkSubmitArguments(clArgs) + val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + val childArgsStr = childArgs.mkString(" ") + childArgsStr should include ("--jar thejar.jar") + childArgsStr should include ("--class org.SomeClass") + childArgsStr should include ("--addJars one.jar,two.jar,three.jar") + childArgsStr should include ("--executor-memory 5g") + childArgsStr should include ("--driver-memory 4g") + childArgsStr should include ("--executor-cores 5") + childArgsStr should include ("--args arg1 --args arg2") + childArgsStr should include ("--queue thequeue") + childArgsStr should include ("--files file1.txt,file2.txt") + childArgsStr should include ("--archives archive1.txt,archive2.txt") + childArgsStr should include ("--num-executors 6") + mainClass should be ("org.apache.spark.deploy.yarn.Client") + classpath should have length (0) + sysProps should have size (0) + } + + test("handles YARN client mode") { + val clArgs = Array("thejar.jar", "--deploy-mode", "client", + "--master", "yarn", "--executor-memory", "5g", "--executor-cores", "5", + "--class", "org.SomeClass", "--jars", "one.jar,two.jar,three.jar", + "--arg", "arg1", "--arg", "arg2", "--driver-memory", "4g", + "--queue", "thequeue", "--files", "file1.txt,file2.txt", + "--archives", "archive1.txt,archive2.txt", "--num-executors", "6") + val appArgs = new SparkSubmitArguments(clArgs) + val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + childArgs.mkString(" ") should be ("arg1 arg2") + mainClass should be ("org.SomeClass") + classpath should contain ("thejar.jar") + classpath should contain ("one.jar") + classpath should contain ("two.jar") + classpath should contain ("three.jar") + sysProps("spark.executor.memory") should be ("5g") + sysProps("spark.executor.cores") should be ("5") + sysProps("spark.yarn.queue") should be ("thequeue") + sysProps("spark.yarn.dist.files") should be ("file1.txt,file2.txt") + sysProps("spark.yarn.dist.archives") should be ("archive1.txt,archive2.txt") + sysProps("spark.executor.instances") should be ("6") + } + + test("handles standalone cluster mode") { + val clArgs = Array("thejar.jar", "--deploy-mode", "cluster", + "--master", "spark://h:p", "--class", "org.SomeClass", "--arg", "arg1", "--arg", "arg2", + "--supervise", "--driver-memory", "4g", "--driver-cores", "5") + val appArgs = new SparkSubmitArguments(clArgs) + val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + val childArgsStr = childArgs.mkString(" ") + print("child args: " + childArgsStr) + childArgsStr.startsWith("--memory 4g --cores 5 --supervise") should be (true) + childArgsStr should include ("launch spark://h:p thejar.jar org.SomeClass arg1 arg2") + mainClass should be ("org.apache.spark.deploy.Client") + classpath should have length (0) + sysProps should have size (0) + } + + test("handles standalone client mode") { + val clArgs = Array("thejar.jar", "--deploy-mode", "client", + "--master", "spark://h:p", "--executor-memory", "5g", "--total-executor-cores", "5", + "--class", "org.SomeClass", "--arg", "arg1", "--arg", "arg2", + "--driver-memory", "4g") + val appArgs = new SparkSubmitArguments(clArgs) + val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + childArgs.mkString(" ") should be ("arg1 arg2") + mainClass should be ("org.SomeClass") + classpath should contain ("thejar.jar") + sysProps("spark.executor.memory") should be ("5g") + sysProps("spark.cores.max") should be ("5") + } + + test("handles mesos client mode") { + val clArgs = Array("thejar.jar", "--deploy-mode", "client", + "--master", "mesos://h:p", "--executor-memory", "5g", "--total-executor-cores", "5", + "--class", "org.SomeClass", "--arg", "arg1", "--arg", "arg2", + "--driver-memory", "4g") + val appArgs = new SparkSubmitArguments(clArgs) + val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + childArgs.mkString(" ") should be ("arg1 arg2") + mainClass should be ("org.SomeClass") + classpath should contain ("thejar.jar") + sysProps("spark.executor.memory") should be ("5g") + sysProps("spark.cores.max") should be ("5") + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index c97543f57d8f3..ce567b0cde85d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -428,7 +428,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont assert(scheduler.pendingTasks.isEmpty) assert(scheduler.activeJobs.isEmpty) assert(scheduler.failedStages.isEmpty) - assert(scheduler.stageIdToActiveJob.isEmpty) + assert(scheduler.jobIdToActiveJob.isEmpty) assert(scheduler.jobIdToStageIds.isEmpty) assert(scheduler.stageIdToJobIds.isEmpty) assert(scheduler.stageIdToStage.isEmpty) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index a25ce35736146..7c843772bc2e0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -111,7 +111,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc val listener = new SaveStageAndTaskInfo sc.addSparkListener(listener) sc.addSparkListener(new StatsReportListener) - //just to make sure some of the tasks take a noticeable amount of time + // just to make sure some of the tasks take a noticeable amount of time val w = {i:Int => if (i == 0) Thread.sleep(100) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 67c0a434c9b52..40c29014c4b59 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -112,7 +112,6 @@ class JsonProtocolSuite extends FunSuite { testBlockId(BroadcastHelperBlockId(BroadcastBlockId(2L), "Spark")) testBlockId(TaskResultBlockId(1L)) testBlockId(StreamBlockId(1, 2L)) - testBlockId(TempBlockId(UUID.randomUUID())) } @@ -168,8 +167,8 @@ class JsonProtocolSuite extends FunSuite { } private def testBlockId(blockId: BlockId) { - val newBlockId = JsonProtocol.blockIdFromJson(JsonProtocol.blockIdToJson(blockId)) - blockId == newBlockId + val newBlockId = BlockId(blockId.toString) + assert(blockId === newBlockId) } @@ -180,90 +179,90 @@ class JsonProtocolSuite extends FunSuite { private def assertEquals(event1: SparkListenerEvent, event2: SparkListenerEvent) { (event1, event2) match { case (e1: SparkListenerStageSubmitted, e2: SparkListenerStageSubmitted) => - assert(e1.properties == e2.properties) + assert(e1.properties === e2.properties) assertEquals(e1.stageInfo, e2.stageInfo) case (e1: SparkListenerStageCompleted, e2: SparkListenerStageCompleted) => assertEquals(e1.stageInfo, e2.stageInfo) case (e1: SparkListenerTaskStart, e2: SparkListenerTaskStart) => - assert(e1.stageId == e2.stageId) + assert(e1.stageId === e2.stageId) assertEquals(e1.taskInfo, e2.taskInfo) case (e1: SparkListenerTaskGettingResult, e2: SparkListenerTaskGettingResult) => assertEquals(e1.taskInfo, e2.taskInfo) case (e1: SparkListenerTaskEnd, e2: SparkListenerTaskEnd) => - assert(e1.stageId == e2.stageId) - assert(e1.taskType == e2.taskType) + assert(e1.stageId === e2.stageId) + assert(e1.taskType === e2.taskType) assertEquals(e1.reason, e2.reason) assertEquals(e1.taskInfo, e2.taskInfo) assertEquals(e1.taskMetrics, e2.taskMetrics) case (e1: SparkListenerJobStart, e2: SparkListenerJobStart) => - assert(e1.jobId == e2.jobId) - assert(e1.properties == e2.properties) - assertSeqEquals(e1.stageIds, e2.stageIds, (i1: Int, i2: Int) => assert(i1 == i2)) + assert(e1.jobId === e2.jobId) + assert(e1.properties === e2.properties) + assertSeqEquals(e1.stageIds, e2.stageIds, (i1: Int, i2: Int) => assert(i1 === i2)) case (e1: SparkListenerJobEnd, e2: SparkListenerJobEnd) => - assert(e1.jobId == e2.jobId) + assert(e1.jobId === e2.jobId) assertEquals(e1.jobResult, e2.jobResult) case (e1: SparkListenerEnvironmentUpdate, e2: SparkListenerEnvironmentUpdate) => assertEquals(e1.environmentDetails, e2.environmentDetails) case (e1: SparkListenerBlockManagerAdded, e2: SparkListenerBlockManagerAdded) => - assert(e1.maxMem == e2.maxMem) + assert(e1.maxMem === e2.maxMem) assertEquals(e1.blockManagerId, e2.blockManagerId) case (e1: SparkListenerBlockManagerRemoved, e2: SparkListenerBlockManagerRemoved) => assertEquals(e1.blockManagerId, e2.blockManagerId) case (e1: SparkListenerUnpersistRDD, e2: SparkListenerUnpersistRDD) => - assert(e1.rddId == e2.rddId) + assert(e1.rddId === e2.rddId) case (SparkListenerShutdown, SparkListenerShutdown) => case _ => fail("Events don't match in types!") } } private def assertEquals(info1: StageInfo, info2: StageInfo) { - assert(info1.stageId == info2.stageId) - assert(info1.name == info2.name) - assert(info1.numTasks == info2.numTasks) - assert(info1.submissionTime == info2.submissionTime) - assert(info1.completionTime == info2.completionTime) - assert(info1.emittedTaskSizeWarning == info2.emittedTaskSizeWarning) + assert(info1.stageId === info2.stageId) + assert(info1.name === info2.name) + assert(info1.numTasks === info2.numTasks) + assert(info1.submissionTime === info2.submissionTime) + assert(info1.completionTime === info2.completionTime) + assert(info1.emittedTaskSizeWarning === info2.emittedTaskSizeWarning) assertEquals(info1.rddInfo, info2.rddInfo) } private def assertEquals(info1: RDDInfo, info2: RDDInfo) { - assert(info1.id == info2.id) - assert(info1.name == info2.name) - assert(info1.numPartitions == info2.numPartitions) - assert(info1.numCachedPartitions == info2.numCachedPartitions) - assert(info1.memSize == info2.memSize) - assert(info1.diskSize == info2.diskSize) + assert(info1.id === info2.id) + assert(info1.name === info2.name) + assert(info1.numPartitions === info2.numPartitions) + assert(info1.numCachedPartitions === info2.numCachedPartitions) + assert(info1.memSize === info2.memSize) + assert(info1.diskSize === info2.diskSize) assertEquals(info1.storageLevel, info2.storageLevel) } private def assertEquals(level1: StorageLevel, level2: StorageLevel) { - assert(level1.useDisk == level2.useDisk) - assert(level1.useMemory == level2.useMemory) - assert(level1.deserialized == level2.deserialized) - assert(level1.replication == level2.replication) + assert(level1.useDisk === level2.useDisk) + assert(level1.useMemory === level2.useMemory) + assert(level1.deserialized === level2.deserialized) + assert(level1.replication === level2.replication) } private def assertEquals(info1: TaskInfo, info2: TaskInfo) { - assert(info1.taskId == info2.taskId) - assert(info1.index == info2.index) - assert(info1.launchTime == info2.launchTime) - assert(info1.executorId == info2.executorId) - assert(info1.host == info2.host) - assert(info1.taskLocality == info2.taskLocality) - assert(info1.gettingResultTime == info2.gettingResultTime) - assert(info1.finishTime == info2.finishTime) - assert(info1.failed == info2.failed) - assert(info1.serializedSize == info2.serializedSize) + assert(info1.taskId === info2.taskId) + assert(info1.index === info2.index) + assert(info1.launchTime === info2.launchTime) + assert(info1.executorId === info2.executorId) + assert(info1.host === info2.host) + assert(info1.taskLocality === info2.taskLocality) + assert(info1.gettingResultTime === info2.gettingResultTime) + assert(info1.finishTime === info2.finishTime) + assert(info1.failed === info2.failed) + assert(info1.serializedSize === info2.serializedSize) } private def assertEquals(metrics1: TaskMetrics, metrics2: TaskMetrics) { - assert(metrics1.hostname == metrics2.hostname) - assert(metrics1.executorDeserializeTime == metrics2.executorDeserializeTime) - assert(metrics1.resultSize == metrics2.resultSize) - assert(metrics1.jvmGCTime == metrics2.jvmGCTime) - assert(metrics1.resultSerializationTime == metrics2.resultSerializationTime) - assert(metrics1.memoryBytesSpilled == metrics2.memoryBytesSpilled) - assert(metrics1.diskBytesSpilled == metrics2.diskBytesSpilled) + assert(metrics1.hostname === metrics2.hostname) + assert(metrics1.executorDeserializeTime === metrics2.executorDeserializeTime) + assert(metrics1.resultSize === metrics2.resultSize) + assert(metrics1.jvmGCTime === metrics2.jvmGCTime) + assert(metrics1.resultSerializationTime === metrics2.resultSerializationTime) + assert(metrics1.memoryBytesSpilled === metrics2.memoryBytesSpilled) + assert(metrics1.diskBytesSpilled === metrics2.diskBytesSpilled) assertOptionEquals( metrics1.shuffleReadMetrics, metrics2.shuffleReadMetrics, assertShuffleReadEquals) assertOptionEquals( @@ -272,31 +271,31 @@ class JsonProtocolSuite extends FunSuite { } private def assertEquals(metrics1: ShuffleReadMetrics, metrics2: ShuffleReadMetrics) { - assert(metrics1.shuffleFinishTime == metrics2.shuffleFinishTime) - assert(metrics1.totalBlocksFetched == metrics2.totalBlocksFetched) - assert(metrics1.remoteBlocksFetched == metrics2.remoteBlocksFetched) - assert(metrics1.localBlocksFetched == metrics2.localBlocksFetched) - assert(metrics1.fetchWaitTime == metrics2.fetchWaitTime) - assert(metrics1.remoteBytesRead == metrics2.remoteBytesRead) + assert(metrics1.shuffleFinishTime === metrics2.shuffleFinishTime) + assert(metrics1.totalBlocksFetched === metrics2.totalBlocksFetched) + assert(metrics1.remoteBlocksFetched === metrics2.remoteBlocksFetched) + assert(metrics1.localBlocksFetched === metrics2.localBlocksFetched) + assert(metrics1.fetchWaitTime === metrics2.fetchWaitTime) + assert(metrics1.remoteBytesRead === metrics2.remoteBytesRead) } private def assertEquals(metrics1: ShuffleWriteMetrics, metrics2: ShuffleWriteMetrics) { - assert(metrics1.shuffleBytesWritten == metrics2.shuffleBytesWritten) - assert(metrics1.shuffleWriteTime == metrics2.shuffleWriteTime) + assert(metrics1.shuffleBytesWritten === metrics2.shuffleBytesWritten) + assert(metrics1.shuffleWriteTime === metrics2.shuffleWriteTime) } private def assertEquals(bm1: BlockManagerId, bm2: BlockManagerId) { - assert(bm1.executorId == bm2.executorId) - assert(bm1.host == bm2.host) - assert(bm1.port == bm2.port) - assert(bm1.nettyPort == bm2.nettyPort) + assert(bm1.executorId === bm2.executorId) + assert(bm1.host === bm2.host) + assert(bm1.port === bm2.port) + assert(bm1.nettyPort === bm2.nettyPort) } private def assertEquals(result1: JobResult, result2: JobResult) { (result1, result2) match { case (JobSucceeded, JobSucceeded) => case (r1: JobFailed, r2: JobFailed) => - assert(r1.failedStageId == r2.failedStageId) + assert(r1.failedStageId === r2.failedStageId) assertEquals(r1.exception, r2.exception) case _ => fail("Job results don't match in types!") } @@ -307,13 +306,13 @@ class JsonProtocolSuite extends FunSuite { case (Success, Success) => case (Resubmitted, Resubmitted) => case (r1: FetchFailed, r2: FetchFailed) => - assert(r1.shuffleId == r2.shuffleId) - assert(r1.mapId == r2.mapId) - assert(r1.reduceId == r2.reduceId) + assert(r1.shuffleId === r2.shuffleId) + assert(r1.mapId === r2.mapId) + assert(r1.reduceId === r2.reduceId) assertEquals(r1.bmAddress, r2.bmAddress) case (r1: ExceptionFailure, r2: ExceptionFailure) => - assert(r1.className == r2.className) - assert(r1.description == r2.description) + assert(r1.className === r2.className) + assert(r1.description === r2.description) assertSeqEquals(r1.stackTrace, r2.stackTrace, assertStackTraceElementEquals) assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals) case (TaskResultLost, TaskResultLost) => @@ -329,13 +328,13 @@ class JsonProtocolSuite extends FunSuite { details2: Map[String, Seq[(String, String)]]) { details1.zip(details2).foreach { case ((key1, values1: Seq[(String, String)]), (key2, values2: Seq[(String, String)])) => - assert(key1 == key2) - values1.zip(values2).foreach { case (v1, v2) => assert(v1 == v2) } + assert(key1 === key2) + values1.zip(values2).foreach { case (v1, v2) => assert(v1 === v2) } } } private def assertEquals(exception1: Exception, exception2: Exception) { - assert(exception1.getMessage == exception2.getMessage) + assert(exception1.getMessage === exception2.getMessage) assertSeqEquals( exception1.getStackTrace, exception2.getStackTrace, @@ -344,11 +343,11 @@ class JsonProtocolSuite extends FunSuite { private def assertJsonStringEquals(json1: String, json2: String) { val formatJsonString = (json: String) => json.replaceAll("[\\s|]", "") - formatJsonString(json1) == formatJsonString(json2) + formatJsonString(json1) === formatJsonString(json2) } private def assertSeqEquals[T](seq1: Seq[T], seq2: Seq[T], assertEquals: (T, T) => Unit) { - assert(seq1.length == seq2.length) + assert(seq1.length === seq2.length) seq1.zip(seq2).foreach { case (t1, t2) => assertEquals(t1, t2) } @@ -389,11 +388,11 @@ class JsonProtocolSuite extends FunSuite { } private def assertBlockEquals(b1: (BlockId, BlockStatus), b2: (BlockId, BlockStatus)) { - assert(b1 == b2) + assert(b1 === b2) } private def assertStackTraceElementEquals(ste1: StackTraceElement, ste2: StackTraceElement) { - assert(ste1 == ste2) + assert(ste1 === ste2) } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index eb8f5915605de..616214fb5e3a6 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -39,7 +39,7 @@ class UtilsSuite extends FunSuite { } test("copyStream") { - //input array initialization + // input array initialization val bytes = Array.ofDim[Byte](9000) Random.nextBytes(bytes) diff --git a/dev/audit-release/README.md b/dev/audit-release/README.md index 2437a98672177..38becda0eae92 100644 --- a/dev/audit-release/README.md +++ b/dev/audit-release/README.md @@ -4,7 +4,7 @@ run them locally by setting appropriate environment variables. ``` $ cd sbt_app_core -$ SCALA_VERSION=2.10.3 \ +$ SCALA_VERSION=2.10.4 \ SPARK_VERSION=1.0.0-SNAPSHOT \ SPARK_RELEASE_REPOSITORY=file:///home/patrick/.ivy2/local \ sbt run diff --git a/dev/audit-release/audit_release.py b/dev/audit-release/audit_release.py index 52c367d9b030d..fa2f02dfecc75 100755 --- a/dev/audit-release/audit_release.py +++ b/dev/audit-release/audit_release.py @@ -35,7 +35,7 @@ RELEASE_KEY = "9E4FE3AF" RELEASE_REPOSITORY = "https://repository.apache.org/content/repositories/orgapachespark-1006/" RELEASE_VERSION = "1.0.0" -SCALA_VERSION = "2.10.3" +SCALA_VERSION = "2.10.4" SCALA_BINARY_VERSION = "2.10" ## diff --git a/dev/run-tests b/dev/run-tests index 6f115d2abd5b0..a6fcc40a5ba6e 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -39,18 +39,17 @@ JAVA_VERSION=$($java_cmd -version 2>&1 | sed 's/java version "\(.*\)\.\(.*\)\..* echo "=========================================================================" echo "Running Apache RAT checks" echo "=========================================================================" - dev/check-license echo "=========================================================================" echo "Running Scala style checks" echo "=========================================================================" -sbt/sbt clean scalastyle +dev/scalastyle echo "=========================================================================" echo "Running Spark unit tests" echo "=========================================================================" -sbt/sbt assembly test +sbt/sbt assembly test | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" echo "=========================================================================" echo "Running PySpark tests" @@ -64,5 +63,5 @@ echo "=========================================================================" echo "Detecting binary incompatibilites with MiMa" echo "=========================================================================" ./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore -sbt/sbt mima-report-binary-issues +sbt/sbt mima-report-binary-issues | grep -v -e "info.*Resolving" diff --git a/dev/scalastyle b/dev/scalastyle new file mode 100755 index 0000000000000..5a18f4d672825 --- /dev/null +++ b/dev/scalastyle @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +sbt/sbt clean scalastyle > scalastyle.txt +ERRORS=$(cat scalastyle.txt | grep -e "error file") +if test ! -z "$ERRORS"; then + echo -e "Scalastyle checks failed at following occurrences:\n$ERRORS" + exit 1 +else + echo -e "Scalastyle checks passed.\n" +fi diff --git a/docker/spark-test/base/Dockerfile b/docker/spark-test/base/Dockerfile index e543db6143e4d..5956d59130fbf 100644 --- a/docker/spark-test/base/Dockerfile +++ b/docker/spark-test/base/Dockerfile @@ -25,7 +25,7 @@ RUN apt-get update # install a few other useful packages plus Open Jdk 7 RUN apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server -ENV SCALA_VERSION 2.10.3 +ENV SCALA_VERSION 2.10.4 ENV CDH_VERSION cdh4 ENV SCALA_HOME /opt/scala-$SCALA_VERSION ENV SPARK_HOME /opt/spark diff --git a/docs/_config.yml b/docs/_config.yml index aa5a5adbc1743..d585b8c5ea763 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -6,7 +6,7 @@ markdown: kramdown SPARK_VERSION: 1.0.0-SNAPSHOT SPARK_VERSION_SHORT: 1.0.0 SCALA_BINARY_VERSION: "2.10" -SCALA_VERSION: "2.10.3" +SCALA_VERSION: "2.10.4" MESOS_VERSION: 0.13.0 SPARK_ISSUE_TRACKER_URL: https://spark-project.atlassian.net SPARK_GITHUB_URL: https://github.com/apache/spark diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index a555a7b5023e3..b69e3416fb322 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -50,6 +50,50 @@ The system currently supports three cluster managers: In addition, Spark's [EC2 launch scripts](ec2-scripts.html) make it easy to launch a standalone cluster on Amazon EC2. +# Launching Applications + +The recommended way to launch a compiled Spark application is through the spark-submit script (located in the +bin directory), which takes care of setting up the classpath with Spark and its dependencies, as well as +provides a layer over the different cluster managers and deploy modes that Spark supports. It's usage is + + spark-submit `` `` + +Where options are any of: + +- **\--class** - The main class to run. +- **\--master** - The URL of the cluster manager master, e.g. spark://host:port, mesos://host:port, yarn, + or local. +- **\--deploy-mode** - "client" to run the driver in the client process or "cluster" to run the driver in + a process on the cluster. For Mesos, only "client" is supported. +- **\--executor-memory** - Memory per executor (e.g. 1000M, 2G). +- **\--executor-cores** - Number of cores per executor. (Default: 2) +- **\--driver-memory** - Memory for driver (e.g. 1000M, 2G) +- **\--name** - Name of the application. +- **\--arg** - Argument to be passed to the application's main class. This option can be specified + multiple times to pass multiple arguments. +- **\--jars** - A comma-separated list of local jars to include on the driver classpath and that + SparkContext.addJar will work with. Doesn't work on standalone with 'cluster' deploy mode. + +The following currently only work for Spark standalone with cluster deploy mode: + +- **\--driver-cores** - Cores for driver (Default: 1). +- **\--supervise** - If given, restarts the driver on failure. + +The following only works for Spark standalone and Mesos only: + +- **\--total-executor-cores** - Total cores for all executors. + +The following currently only work for YARN: + +- **\--queue** - The YARN queue to place the application in. +- **\--files** - Comma separated list of files to be placed in the working dir of each executor. +- **\--archives** - Comma separated list of archives to be extracted into the working dir of each + executor. +- **\--num-executors** - Number of executors (Default: 2). + +The master and deploy mode can also be set with the MASTER and DEPLOY_MODE environment variables. +Values for these options passed via command line will override the environment variables. + # Shipping Code to the Cluster The recommended way to ship your code to the cluster is to pass it through SparkContext's constructor, @@ -102,6 +146,12 @@ The following table summarizes terms you'll see used to refer to cluster concept Cluster manager An external service for acquiring resources on the cluster (e.g. standalone manager, Mesos, YARN) + + Deploy mode + Distinguishes where the driver process runs. In "cluster" mode, the framework launches + the driver inside of the cluster. In "client" mode, the submitter launches the driver + outside of the cluster. + Worker node Any node that can run application code in the cluster diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 2e9dec4856ee9..982514391ac00 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -48,10 +48,12 @@ System Properties: Ensure that HADOOP_CONF_DIR or YARN_CONF_DIR points to the directory which contains the (client side) configuration files for the Hadoop cluster. These configs are used to connect to the cluster, write to the dfs, and connect to the YARN ResourceManager. -There are two scheduler modes that can be used to launch Spark applications on YARN. In yarn-cluster mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In yarn-client mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. +There are two deploy modes that can be used to launch Spark applications on YARN. In yarn-cluster mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In yarn-client mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. Unlike in Spark standalone and Mesos mode, in which the master's address is specified in the "master" parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the master parameter is simply "yarn-client" or "yarn-cluster". +The spark-submit script described in the [cluster mode overview](cluster-overview.html) provides the most straightforward way to submit a compiled Spark application to YARN in either deploy mode. For info on the lower-level invocations it uses, read ahead. For running spark-shell against YARN, skip down to the yarn-client section. + ## Launching a Spark application with yarn-cluster mode. The command to launch the Spark application on the cluster is as follows: @@ -59,7 +61,7 @@ The command to launch the Spark application on the cluster is as follows: SPARK_JAR= ./bin/spark-class org.apache.spark.deploy.yarn.Client \ --jar \ --class \ - --args \ + --arg \ --num-executors \ --driver-memory \ --executor-memory \ @@ -70,7 +72,7 @@ The command to launch the Spark application on the cluster is as follows: --files \ --archives -For example: +To pass multiple arguments the "arg" option can be specified multiple times. For example: # Build the Spark assembly JAR and the Spark examples JAR $ SPARK_HADOOP_VERSION=2.0.5-alpha SPARK_YARN=true sbt/sbt assembly @@ -83,7 +85,8 @@ For example: ./bin/spark-class org.apache.spark.deploy.yarn.Client \ --jar examples/target/scala-{{site.SCALA_BINARY_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \ --class org.apache.spark.examples.SparkPi \ - --args yarn-cluster \ + --arg yarn-cluster \ + --arg 5 \ --num-executors 3 \ --driver-memory 4g \ --executor-memory 2g \ @@ -121,7 +124,7 @@ or MASTER=yarn-client ./bin/spark-shell -## Viewing logs +# Viewing logs In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the yarn.log-aggregation-enable config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command. diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 51fb3a4f7f8c5..7e4eea323aa63 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -146,10 +146,13 @@ automatically set MASTER from the `SPARK_MASTER_IP` and `SPARK_MASTER_PORT` vari You can also pass an option `-c ` to control the number of cores that spark-shell uses on the cluster. -# Launching Applications Inside the Cluster +# Launching Compiled Spark Applications -You may also run your application entirely inside of the cluster by submitting your application driver using the submission client. The syntax for submitting applications is as follows: +Spark supports two deploy modes. Spark applications may run with the driver inside the client process or entirely inside the cluster. +The spark-submit script described in the [cluster mode overview](cluster-overview.html) provides the most straightforward way to submit a compiled Spark application to the cluster in either deploy mode. For info on the lower-level invocations used to launch an app inside the cluster, read ahead. + +## Launching Applications Inside the Cluster ./bin/spark-class org.apache.spark.deploy.Client launch [client-options] \ diff --git a/examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java b/examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java index 667c72f379e71..cd8879ff886e2 100644 --- a/examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java +++ b/examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java @@ -17,6 +17,7 @@ package org.apache.spark.mllib.examples; +import java.util.regex.Pattern; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -24,11 +25,9 @@ import org.apache.spark.mllib.classification.LogisticRegressionWithSGD; import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import java.util.Arrays; -import java.util.regex.Pattern; - /** * Logistic regression based classification using ML Lib. */ @@ -47,14 +46,10 @@ public LabeledPoint call(String line) { for (int i = 0; i < tok.length; ++i) { x[i] = Double.parseDouble(tok[i]); } - return new LabeledPoint(y, x); + return new LabeledPoint(y, Vectors.dense(x)); } } - public static void printWeights(double[] a) { - System.out.println(Arrays.toString(a)); - } - public static void main(String[] args) { if (args.length != 4) { System.err.println("Usage: JavaLR "); @@ -80,8 +75,7 @@ public static void main(String[] args) { LogisticRegressionModel model = LogisticRegressionWithSGD.train(points.rdd(), iterations, stepSize); - System.out.print("Final w: "); - printWeights(model.weights()); + System.out.print("Final w: " + model.weights()); System.exit(0); } diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala index c8ecbb8e41a86..0095cb8425456 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala @@ -53,7 +53,6 @@ object LocalALS { for (i <- 0 until M; j <- 0 until U) { r.set(i, j, blas.ddot(ms(i), us(j))) } - //println("R: " + r) blas.daxpy(-1, targetR, r) val sumSqs = r.aggregate(Functions.plus, Functions.square) sqrt(sumSqs / (M * U)) diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index 73b0e216cac98..1fdb324b89f3a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -61,7 +61,7 @@ object SimpleSkewedGroupByTest { println("RESULT: " + pairs1.groupByKey(numReducers).count) // Print how many keys each reducer got (for debugging) - //println("RESULT: " + pairs1.groupByKey(numReducers) + // println("RESULT: " + pairs1.groupByKey(numReducers) // .map{case (k,v) => (k, v.size)} // .collectAsMap) diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index ce4b3c8451e00..f59ab7e7cc24a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -54,7 +54,6 @@ object SparkALS { for (i <- 0 until M; j <- 0 until U) { r.set(i, j, blas.ddot(ms(i), us(j))) } - //println("R: " + r) blas.daxpy(-1, targetR, r) val sumSqs = r.aggregate(Functions.plus, Functions.square) sqrt(sumSqs / (M * U)) diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index cf1fc3e808c76..e698b9bf376e1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -34,8 +34,6 @@ object SparkHdfsLR { case class DataPoint(x: Vector, y: Double) def parsePoint(line: String): DataPoint = { - //val nums = line.split(' ').map(_.toDouble) - //return DataPoint(new Vector(nums.slice(1, D+1)), nums(0)) val tok = new java.util.StringTokenizer(line, " ") var y = tok.nextToken.toDouble var x = new Array[Double](D) diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala index 62d3a52615584..a22e64ca3ce45 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala @@ -168,7 +168,7 @@ object ActorWordCount { Props(new SampleActorReceiver[String]("akka.tcp://test@%s:%s/user/FeederActor".format( host, port.toInt))), "SampleReceiver") - //compute wordcount + // compute wordcount lines.flatMap(_.split("\\s+")).map(x => (x, 1)).reduceByKey(_ + _).print() ssc.start() diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala index 35be7ffa1e872..35f8f885f8f0e 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala @@ -88,7 +88,7 @@ object ZeroMQWordCount { def bytesToStringIterator(x: Seq[ByteString]) = (x.map(_.utf8String)).iterator - //For this stream, a zeroMQ publisher should be running. + // For this stream, a zeroMQ publisher should be running. val lines = ZeroMQUtils.createStream(ssc, url, Subscribe(topic), bytesToStringIterator _) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala index 0ac46c31c24c8..251f65fe4df9c 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala @@ -21,7 +21,7 @@ import java.net.ServerSocket import java.io.PrintWriter import util.Random -/** Represents a page view on a website with associated dimension data.*/ +/** Represents a page view on a website with associated dimension data. */ class PageView(val url : String, val status : Int, val zipCode : Int, val userID : Int) extends Serializable { override def toString() : String = { diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index ce3ef47cfe4bc..34012b846e21e 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -127,7 +127,7 @@ class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol { } /** A NetworkReceiver which listens for events using the - * Flume Avro interface.*/ + * Flume Avro interface. */ private[streaming] class FlumeReceiver( host: String, diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala index 6acba25f44c0a..a538c38dc4d6f 100644 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala +++ b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala @@ -44,7 +44,7 @@ private[streaming] class ZeroMQReceiver[T: ClassTag](publisherUrl: String, case m: ZMQMessage => logDebug("Received message for:" + m.frame(0)) - //We ignore first frame for processing as it is the topic + // We ignore first frame for processing as it is the topic val bytes = m.frames.tail pushBlock(bytesToObjects(bytes)) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index f2296a865e1b3..6d04bf790e3a5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -45,7 +45,8 @@ class EdgeRDD[@specialized ED: ClassTag]( partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD))) override def compute(part: Partition, context: TaskContext): Iterator[Edge[ED]] = { - firstParent[(PartitionID, EdgePartition[ED])].iterator(part, context).next._2.iterator + val p = firstParent[(PartitionID, EdgePartition[ED])].iterator(part, context) + p.next._2.iterator.map(_.copy()) } override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect() diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala index fea43c3b2bbf1..dfc6a801587d2 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala @@ -27,12 +27,12 @@ class EdgeTriplet[VD, ED] extends Edge[ED] { /** * The source vertex attribute */ - var srcAttr: VD = _ //nullValue[VD] + var srcAttr: VD = _ // nullValue[VD] /** * The destination vertex attribute */ - var dstAttr: VD = _ //nullValue[VD] + var dstAttr: VD = _ // nullValue[VD] /** * Set the edge properties of this triplet. diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index 57fa5eefd5e09..2e05f5d4e4969 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -56,6 +56,9 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) * Construct a new edge partition by applying the function f to all * edges in this partition. * + * Be careful not to keep references to the objects passed to `f`. + * To improve GC performance the same object is re-used for each call. + * * @param f a function from an edge to a new attribute * @tparam ED2 the type of the new attribute * @return a new edge partition with the result of the function `f` @@ -84,12 +87,12 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) * order of the edges returned by `EdgePartition.iterator` and * should return attributes equal to the number of edges. * - * @param f a function from an edge to a new attribute + * @param iter an iterator for the new attribute values * @tparam ED2 the type of the new attribute - * @return a new edge partition with the result of the function `f` - * applied to each edge + * @return a new edge partition with the attribute values replaced */ def map[ED2: ClassTag](iter: Iterator[ED2]): EdgePartition[ED2] = { + // Faster than iter.toArray, because the expected size is known. val newData = new Array[ED2](data.size) var i = 0 while (iter.hasNext) { @@ -188,6 +191,9 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) /** * Get an iterator over the edges in this partition. * + * Be careful not to keep references to the objects from this iterator. + * To improve GC performance the same object is re-used in `next()`. + * * @return an iterator over edges in the partition */ def iterator = new Iterator[Edge[ED]] { @@ -216,6 +222,9 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) /** * Get an iterator over the cluster of edges in this partition with source vertex id `srcId`. The * cluster must start at position `index`. + * + * Be careful not to keep references to the objects from this iterator. To improve GC performance + * the same object is re-used in `next()`. */ private def clusterIterator(srcId: VertexId, index: Int) = new Iterator[Edge[ED]] { private[this] val edge = new Edge[ED] diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala index 886c250d7cffd..220a89d73d711 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala @@ -37,20 +37,15 @@ class EdgeTripletIterator[VD: ClassTag, ED: ClassTag]( // Current position in the array. private var pos = 0 - // A triplet object that this iterator.next() call returns. We reuse this object to avoid - // allocating too many temporary Java objects. - private val triplet = new EdgeTriplet[VD, ED] - private val vmap = new PrimitiveKeyOpenHashMap[VertexId, VD](vidToIndex, vertexArray) override def hasNext: Boolean = pos < edgePartition.size override def next() = { + val triplet = new EdgeTriplet[VD, ED] triplet.srcId = edgePartition.srcIds(pos) - // assert(vmap.containsKey(e.src.id)) triplet.srcAttr = vmap(triplet.srcId) triplet.dstId = edgePartition.dstIds(pos) - // assert(vmap.containsKey(e.dst.id)) triplet.dstAttr = vmap(triplet.dstId) triplet.attr = edgePartition.data(pos) pos += 1 diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 43ac11d8957f6..c2b510a31ee3f 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -190,9 +190,9 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( new GraphImpl(vertices, newETable, routingTable, replicatedVertexView) } - ////////////////////////////////////////////////////////////////////////////////////////////////// + // /////////////////////////////////////////////////////////////////////////////////////////////// // Lower level transformation methods - ////////////////////////////////////////////////////////////////////////////////////////////////// + // /////////////////////////////////////////////////////////////////////////////////////////////// override def mapReduceTriplets[A: ClassTag]( mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala index fe6fe76defdc5..9d4f3750cb8e4 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala @@ -45,7 +45,7 @@ class VertexBroadcastMsg[@specialized(Int, Long, Double, Boolean) T]( * @param data value to send */ private[graphx] -class MessageToPartition[@specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T]( +class MessageToPartition[@specialized(Int, Long, Double, Char, Boolean/* , AnyRef */) T]( @transient var partition: PartitionID, var data: T) extends Product2[PartitionID, T] with Serializable { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala index 34a145e01818f..2f2c524df6394 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala @@ -298,7 +298,6 @@ abstract class ShuffleSerializationStream(s: OutputStream) extends Serialization s.write(v.toInt) } - //def writeDouble(v: Double): Unit = writeUnsignedVarLong(java.lang.Double.doubleToLongBits(v)) def writeDouble(v: Double): Unit = writeLong(java.lang.Double.doubleToLongBits(v)) override def flush(): Unit = s.flush() @@ -391,7 +390,6 @@ abstract class ShuffleDeserializationStream(s: InputStream) extends Deserializat (s.read() & 0xFF) } - //def readDouble(): Double = java.lang.Double.longBitsToDouble(readUnsignedVarLong()) def readDouble(): Double = java.lang.Double.longBitsToDouble(readLong()) override def close(): Unit = s.close() diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala index 014a7335f85cc..087b1156f690b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/BytecodeUtils.scala @@ -65,7 +65,7 @@ private[graphx] object BytecodeUtils { val finder = new MethodInvocationFinder(c.getName, m) getClassReader(c).accept(finder, 0) for (classMethod <- finder.methodsInvoked) { - //println(classMethod) + // println(classMethod) if (classMethod._1 == targetClass && classMethod._2 == targetMethod) { return true } else if (!seen.contains(classMethod)) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index f841846c0e510..a3c8de3f9068f 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -123,7 +123,7 @@ object GraphGenerators { * the dimensions of the adjacency matrix */ private def addEdge(numVertices: Int): Edge[Int] = { - //val (src, dst) = chooseCell(numVertices/2.0, numVertices/2.0, numVertices/2.0) + // val (src, dst) = chooseCell(numVertices/2.0, numVertices/2.0, numVertices/2.0) val v = math.round(numVertices.toFloat/2.0).toInt val (src, dst) = chooseCell(v, v, v) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala new file mode 100644 index 0000000000000..9cbb2d2acdc2d --- /dev/null +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx.impl + +import scala.reflect.ClassTag +import scala.util.Random + +import org.scalatest.FunSuite + +import org.apache.spark.graphx._ + +class EdgeTripletIteratorSuite extends FunSuite { + test("iterator.toList") { + val builder = new EdgePartitionBuilder[Int] + builder.add(1, 2, 0) + builder.add(1, 3, 0) + builder.add(1, 4, 0) + val vidmap = new VertexIdToIndexMap + vidmap.add(1) + vidmap.add(2) + vidmap.add(3) + vidmap.add(4) + val vs = Array.fill(vidmap.capacity)(0) + val iter = new EdgeTripletIterator[Int, Int](vidmap, vs, builder.toEdgePartition) + val result = iter.toList.map(et => (et.srcId, et.dstId)) + assert(result === Seq((1, 2), (1, 3), (1, 4))) + } +} diff --git a/make-distribution.sh b/make-distribution.sh index 6bc6819d8da92..5c780fcbda863 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -128,7 +128,7 @@ if [ "$SPARK_TACHYON" == "true" ]; then TACHYON_VERSION="0.4.1" TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/tachyon-${TACHYON_VERSION}-bin.tar.gz" - TMPD=`mktemp -d` + TMPD=`mktemp -d 2>/dev/null || mktemp -d -t 'disttmp'` pushd $TMPD > /dev/null echo "Fetchting tachyon tgz" @@ -139,7 +139,13 @@ if [ "$SPARK_TACHYON" == "true" ]; then mkdir -p "$DISTDIR/tachyon/src/main/java/tachyon/web" cp -r "tachyon-${TACHYON_VERSION}"/{bin,conf,libexec} "$DISTDIR/tachyon" cp -r "tachyon-${TACHYON_VERSION}"/src/main/java/tachyon/web/resources "$DISTDIR/tachyon/src/main/java/tachyon/web" - sed -i "s|export TACHYON_JAR=\$TACHYON_HOME/target/\(.*\)|# This is set for spark's make-distribution\n export TACHYON_JAR=\$TACHYON_HOME/../../jars/\1|" "$DISTDIR/tachyon/libexec/tachyon-config.sh" + + if [[ `uname -a` == Darwin* ]]; then + # need to run sed differently on osx + nl=$'\n'; sed -i "" -e "s|export TACHYON_JAR=\$TACHYON_HOME/target/\(.*\)|# This is set for spark's make-distribution\\$nl export TACHYON_JAR=\$TACHYON_HOME/../jars/\1|" "$DISTDIR/tachyon/libexec/tachyon-config.sh" + else + sed -i "s|export TACHYON_JAR=\$TACHYON_HOME/target/\(.*\)|# This is set for spark's make-distribution\n export TACHYON_JAR=\$TACHYON_HOME/../jars/\1|" "$DISTDIR/tachyon/libexec/tachyon-config.sh" + fi popd > /dev/null rm -rf $TMPD diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 3449c698da60b..2df5b0d02b699 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -110,16 +110,16 @@ class PythonMLLibAPI extends Serializable { private def trainRegressionModel( trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel, - dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]): - java.util.LinkedList[java.lang.Object] = { + dataBytesJRDD: JavaRDD[Array[Byte]], + initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = { val data = dataBytesJRDD.rdd.map(xBytes => { val x = deserializeDoubleVector(xBytes) - LabeledPoint(x(0), x.slice(1, x.length)) + LabeledPoint(x(0), Vectors.dense(x.slice(1, x.length))) }) val initialWeights = deserializeDoubleVector(initialWeightsBA) val model = trainFunc(data, initialWeights) val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(serializeDoubleVector(model.weights)) + ret.add(serializeDoubleVector(model.weights.toArray)) ret.add(model.intercept: java.lang.Double) ret } @@ -127,75 +127,127 @@ class PythonMLLibAPI extends Serializable { /** * Java stub for Python mllib LinearRegressionWithSGD.train() */ - def trainLinearRegressionModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], - numIterations: Int, stepSize: Double, miniBatchFraction: Double, + def trainLinearRegressionModelWithSGD( + dataBytesJRDD: JavaRDD[Array[Byte]], + numIterations: Int, + stepSize: Double, + miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { - trainRegressionModel((data, initialWeights) => - LinearRegressionWithSGD.train(data, numIterations, stepSize, - miniBatchFraction, initialWeights), - dataBytesJRDD, initialWeightsBA) + trainRegressionModel( + (data, initialWeights) => + LinearRegressionWithSGD.train( + data, + numIterations, + stepSize, + miniBatchFraction, + Vectors.dense(initialWeights)), + dataBytesJRDD, + initialWeightsBA) } /** * Java stub for Python mllib LassoWithSGD.train() */ - def trainLassoModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, - stepSize: Double, regParam: Double, miniBatchFraction: Double, + def trainLassoModelWithSGD( + dataBytesJRDD: JavaRDD[Array[Byte]], + numIterations: Int, + stepSize: Double, + regParam: Double, + miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { - trainRegressionModel((data, initialWeights) => - LassoWithSGD.train(data, numIterations, stepSize, regParam, - miniBatchFraction, initialWeights), - dataBytesJRDD, initialWeightsBA) + trainRegressionModel( + (data, initialWeights) => + LassoWithSGD.train( + data, + numIterations, + stepSize, + regParam, + miniBatchFraction, + Vectors.dense(initialWeights)), + dataBytesJRDD, + initialWeightsBA) } /** * Java stub for Python mllib RidgeRegressionWithSGD.train() */ - def trainRidgeModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, - stepSize: Double, regParam: Double, miniBatchFraction: Double, + def trainRidgeModelWithSGD( + dataBytesJRDD: JavaRDD[Array[Byte]], + numIterations: Int, + stepSize: Double, + regParam: Double, + miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { - trainRegressionModel((data, initialWeights) => - RidgeRegressionWithSGD.train(data, numIterations, stepSize, regParam, - miniBatchFraction, initialWeights), - dataBytesJRDD, initialWeightsBA) + trainRegressionModel( + (data, initialWeights) => + RidgeRegressionWithSGD.train( + data, + numIterations, + stepSize, + regParam, + miniBatchFraction, + Vectors.dense(initialWeights)), + dataBytesJRDD, + initialWeightsBA) } /** * Java stub for Python mllib SVMWithSGD.train() */ - def trainSVMModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, - stepSize: Double, regParam: Double, miniBatchFraction: Double, + def trainSVMModelWithSGD( + dataBytesJRDD: JavaRDD[Array[Byte]], + numIterations: Int, + stepSize: Double, + regParam: Double, + miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { - trainRegressionModel((data, initialWeights) => - SVMWithSGD.train(data, numIterations, stepSize, regParam, - miniBatchFraction, initialWeights), - dataBytesJRDD, initialWeightsBA) + trainRegressionModel( + (data, initialWeights) => + SVMWithSGD.train( + data, + numIterations, + stepSize, + regParam, + miniBatchFraction, + Vectors.dense(initialWeights)), + dataBytesJRDD, + initialWeightsBA) } /** * Java stub for Python mllib LogisticRegressionWithSGD.train() */ - def trainLogisticRegressionModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], - numIterations: Int, stepSize: Double, miniBatchFraction: Double, + def trainLogisticRegressionModelWithSGD( + dataBytesJRDD: JavaRDD[Array[Byte]], + numIterations: Int, + stepSize: Double, + miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { - trainRegressionModel((data, initialWeights) => - LogisticRegressionWithSGD.train(data, numIterations, stepSize, - miniBatchFraction, initialWeights), - dataBytesJRDD, initialWeightsBA) + trainRegressionModel( + (data, initialWeights) => + LogisticRegressionWithSGD.train( + data, + numIterations, + stepSize, + miniBatchFraction, + Vectors.dense(initialWeights)), + dataBytesJRDD, + initialWeightsBA) } /** * Java stub for NaiveBayes.train() */ - def trainNaiveBayes(dataBytesJRDD: JavaRDD[Array[Byte]], lambda: Double) - : java.util.List[java.lang.Object] = - { + def trainNaiveBayes( + dataBytesJRDD: JavaRDD[Array[Byte]], + lambda: Double): java.util.List[java.lang.Object] = { val data = dataBytesJRDD.rdd.map(xBytes => { val x = deserializeDoubleVector(xBytes) - LabeledPoint(x(0), x.slice(1, x.length)) + LabeledPoint(x(0), Vectors.dense(x.slice(1, x.length))) }) val model = NaiveBayes.train(data, lambda) val ret = new java.util.LinkedList[java.lang.Object]() + ret.add(serializeDoubleVector(model.labels)) ret.add(serializeDoubleVector(model.pi)) ret.add(serializeDoubleMatrix(model.theta)) ret @@ -204,9 +256,12 @@ class PythonMLLibAPI extends Serializable { /** * Java stub for Python mllib KMeans.train() */ - def trainKMeansModel(dataBytesJRDD: JavaRDD[Array[Byte]], k: Int, - maxIterations: Int, runs: Int, initializationMode: String): - java.util.List[java.lang.Object] = { + def trainKMeansModel( + dataBytesJRDD: JavaRDD[Array[Byte]], + k: Int, + maxIterations: Int, + runs: Int, + initializationMode: String): java.util.List[java.lang.Object] = { val data = dataBytesJRDD.rdd.map(xBytes => Vectors.dense(deserializeDoubleVector(xBytes))) val model = KMeans.train(data, k, maxIterations, runs, initializationMode) val ret = new java.util.LinkedList[java.lang.Object]() @@ -259,8 +314,12 @@ class PythonMLLibAPI extends Serializable { * needs to be taken in the Python code to ensure it gets freed on exit; see * the Py4J documentation. */ - def trainALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int, - iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = { + def trainALSModel( + ratingsBytesJRDD: JavaRDD[Array[Byte]], + rank: Int, + iterations: Int, + lambda: Double, + blocks: Int): MatrixFactorizationModel = { val ratings = ratingsBytesJRDD.rdd.map(unpackRating) ALS.train(ratings, rank, iterations, lambda, blocks) } @@ -271,8 +330,13 @@ class PythonMLLibAPI extends Serializable { * Extra care needs to be taken in the Python code to ensure it gets freed on * exit; see the Py4J documentation. */ - def trainImplicitALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int, - iterations: Int, lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = { + def trainImplicitALSModel( + ratingsBytesJRDD: JavaRDD[Array[Byte]], + rank: Int, + iterations: Int, + lambda: Double, + blocks: Int, + alpha: Double): MatrixFactorizationModel = { val ratings = ratingsBytesJRDD.rdd.map(unpackRating) ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala index 391f5b9b7a7de..bd10e2e9e10e2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -17,22 +17,27 @@ package org.apache.spark.mllib.classification +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD +/** + * Represents a classification model that predicts to which of a set of categories an example + * belongs. The categories are represented by double values: 0.0, 1.0, 2.0, etc. + */ trait ClassificationModel extends Serializable { /** * Predict values for the given data set using the model trained. * * @param testData RDD representing data points to be predicted - * @return RDD[Int] where each entry contains the corresponding prediction + * @return an RDD[Double] where each entry contains the corresponding prediction */ - def predict(testData: RDD[Array[Double]]): RDD[Double] + def predict(testData: RDD[Vector]): RDD[Double] /** * Predict values for a single data point using the model trained. * * @param testData array representing a single data point - * @return Int prediction from the trained model + * @return predicted category from the trained model */ - def predict(testData: Array[Double]): Double + def predict(testData: Vector): Double } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index a481f522761e2..798f3a5c94740 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -17,16 +17,12 @@ package org.apache.spark.mllib.classification -import scala.math.round - import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.util.DataValidators - -import org.jblas.DoubleMatrix +import org.apache.spark.mllib.util.{DataValidators, MLUtils} +import org.apache.spark.rdd.RDD /** * Classification model trained using Logistic Regression. @@ -35,15 +31,38 @@ import org.jblas.DoubleMatrix * @param intercept Intercept computed for this model. */ class LogisticRegressionModel( - override val weights: Array[Double], + override val weights: Vector, override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) - with ClassificationModel with Serializable { + extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { + + private var threshold: Option[Double] = Some(0.5) + + /** + * Sets the threshold that separates positive predictions from negative predictions. An example + * with prediction score greater than or equal to this threshold is identified as an positive, + * and negative otherwise. The default value is 0.5. + */ + def setThreshold(threshold: Double): this.type = { + this.threshold = Some(threshold) + this + } - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, + /** + * Clears the threshold so that `predict` will output raw prediction scores. + */ + def clearThreshold(): this.type = { + threshold = None + this + } + + override def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double) = { - val margin = dataMatrix.mmul(weightMatrix).get(0) + intercept - round(1.0/ (1.0 + math.exp(margin * -1))) + val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept + val score = 1.0/ (1.0 + math.exp(-margin)) + threshold match { + case Some(t) => if (score < t) 0.0 else 1.0 + case None => score + } } } @@ -56,16 +75,15 @@ class LogisticRegressionWithSGD private ( var numIterations: Int, var regParam: Double, var miniBatchFraction: Double) - extends GeneralizedLinearAlgorithm[LogisticRegressionModel] - with Serializable { + extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable { val gradient = new LogisticGradient() val updater = new SimpleUpdater() override val optimizer = new GradientDescent(gradient, updater) - .setStepSize(stepSize) - .setNumIterations(numIterations) - .setRegParam(regParam) - .setMiniBatchFraction(miniBatchFraction) + .setStepSize(stepSize) + .setNumIterations(numIterations) + .setRegParam(regParam) + .setMiniBatchFraction(miniBatchFraction) override val validators = List(DataValidators.classificationLabels) /** @@ -73,7 +91,7 @@ class LogisticRegressionWithSGD private ( */ def this() = this(1.0, 100, 0.0, 1.0) - def createModel(weights: Array[Double], intercept: Double) = { + def createModel(weights: Vector, intercept: Double) = { new LogisticRegressionModel(weights, intercept) } } @@ -105,11 +123,9 @@ object LogisticRegressionWithSGD { numIterations: Int, stepSize: Double, miniBatchFraction: Double, - initialWeights: Array[Double]) - : LogisticRegressionModel = - { - new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run( - input, initialWeights) + initialWeights: Vector): LogisticRegressionModel = { + new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction) + .run(input, initialWeights) } /** @@ -128,11 +144,9 @@ object LogisticRegressionWithSGD { input: RDD[LabeledPoint], numIterations: Int, stepSize: Double, - miniBatchFraction: Double) - : LogisticRegressionModel = - { - new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run( - input) + miniBatchFraction: Double): LogisticRegressionModel = { + new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction) + .run(input) } /** @@ -150,9 +164,7 @@ object LogisticRegressionWithSGD { def train( input: RDD[LabeledPoint], numIterations: Int, - stepSize: Double) - : LogisticRegressionModel = - { + stepSize: Double): LogisticRegressionModel = { train(input, numIterations, stepSize, 1.0) } @@ -168,9 +180,7 @@ object LogisticRegressionWithSGD { */ def train( input: RDD[LabeledPoint], - numIterations: Int) - : LogisticRegressionModel = - { + numIterations: Int): LogisticRegressionModel = { train(input, numIterations, 1.0, 1.0) } @@ -183,7 +193,7 @@ object LogisticRegressionWithSGD { val sc = new SparkContext(args(0), "LogisticRegression") val data = MLUtils.loadLabeledData(sc, args(1)) val model = LogisticRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble) - println("Weights: " + model.weights.mkString("[", ", ", "]")) + println("Weights: " + model.weights) println("Intercept: " + model.intercept) sc.stop() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 6539b2f339465..e956185319a69 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -17,14 +17,14 @@ package org.apache.spark.mllib.classification -import scala.collection.mutable +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} -import org.jblas.DoubleMatrix - -import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD /** * Model for Naive Bayes Classifiers. @@ -32,19 +32,28 @@ import org.apache.spark.mllib.util.MLUtils * @param pi Log of class priors, whose dimension is C. * @param theta Log of class conditional probabilities, whose dimension is CxD. */ -class NaiveBayesModel(val pi: Array[Double], val theta: Array[Array[Double]]) - extends ClassificationModel with Serializable { - - // Create a column vector that can be used for predictions - private val _pi = new DoubleMatrix(pi.length, 1, pi: _*) - private val _theta = new DoubleMatrix(theta) +class NaiveBayesModel( + val labels: Array[Double], + val pi: Array[Double], + val theta: Array[Array[Double]]) extends ClassificationModel with Serializable { + + private val brzPi = new BDV[Double](pi) + private val brzTheta = new BDM[Double](theta.length, theta(0).length) + + var i = 0 + while (i < theta.length) { + var j = 0 + while (j < theta(i).length) { + brzTheta(i, j) = theta(i)(j) + j += 1 + } + i += 1 + } - def predict(testData: RDD[Array[Double]]): RDD[Double] = testData.map(predict) + override def predict(testData: RDD[Vector]): RDD[Double] = testData.map(predict) - def predict(testData: Array[Double]): Double = { - val dataMatrix = new DoubleMatrix(testData.length, 1, testData: _*) - val result = _pi.add(_theta.mmul(dataMatrix)) - result.argmax() + override def predict(testData: Vector): Double = { + labels(brzArgmax(brzPi + brzTheta * testData.toBreeze)) } } @@ -56,9 +65,8 @@ class NaiveBayesModel(val pi: Array[Double], val theta: Array[Array[Double]]) * document classification. By making every vector a 0-1 vector, it can also be used as * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). */ -class NaiveBayes private (var lambda: Double) - extends Serializable with Logging -{ +class NaiveBayes private (var lambda: Double) extends Serializable with Logging { + def this() = this(1.0) /** Set the smoothing parameter. Default: 1.0. */ @@ -70,45 +78,42 @@ class NaiveBayes private (var lambda: Double) /** * Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries. * - * @param data RDD of (label, array of features) pairs. + * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. */ def run(data: RDD[LabeledPoint]) = { - // Aggregates all sample points to driver side to get sample count and summed feature vector - // for each label. The shape of `zeroCombiner` & `aggregated` is: - // - // label: Int -> (count: Int, featuresSum: DoubleMatrix) - val zeroCombiner = mutable.Map.empty[Int, (Int, DoubleMatrix)] - val aggregated = data.aggregate(zeroCombiner)({ (combiner, point) => - point match { - case LabeledPoint(label, features) => - val (count, featuresSum) = combiner.getOrElse(label.toInt, (0, DoubleMatrix.zeros(1))) - val fs = new DoubleMatrix(features.length, 1, features: _*) - combiner += label.toInt -> (count + 1, featuresSum.addi(fs)) - } - }, { (lhs, rhs) => - for ((label, (c, fs)) <- rhs) { - val (count, featuresSum) = lhs.getOrElse(label, (0, DoubleMatrix.zeros(1))) - lhs(label) = (count + c, featuresSum.addi(fs)) + // Aggregates term frequencies per label. + // TODO: Calling combineByKey and collect creates two stages, we can implement something + // TODO: similar to reduceByKeyLocally to save one stage. + val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])]( + createCombiner = (v: Vector) => (1L, v.toBreeze.toDenseVector), + mergeValue = (c: (Long, BDV[Double]), v: Vector) => (c._1 + 1L, c._2 += v.toBreeze), + mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) => + (c1._1 + c2._1, c1._2 += c2._2) + ).collect() + val numLabels = aggregated.length + var numDocuments = 0L + aggregated.foreach { case (_, (n, _)) => + numDocuments += n + } + val numFeatures = aggregated.head match { case (_, (_, v)) => v.size } + val labels = new Array[Double](numLabels) + val pi = new Array[Double](numLabels) + val theta = Array.fill(numLabels)(new Array[Double](numFeatures)) + val piLogDenom = math.log(numDocuments + numLabels * lambda) + var i = 0 + aggregated.foreach { case (label, (n, sumTermFreqs)) => + labels(i) = label + val thetaLogDenom = math.log(brzSum(sumTermFreqs) + numFeatures * lambda) + pi(i) = math.log(n + lambda) - piLogDenom + var j = 0 + while (j < numFeatures) { + theta(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom + j += 1 } - lhs - }) - - // Kinds of label - val C = aggregated.size - // Total sample count - val N = aggregated.values.map(_._1).sum - - val pi = new Array[Double](C) - val theta = new Array[Array[Double]](C) - val piLogDenom = math.log(N + C * lambda) - - for ((label, (count, fs)) <- aggregated) { - val thetaLogDenom = math.log(fs.sum() + fs.length * lambda) - pi(label) = math.log(count + lambda) - piLogDenom - theta(label) = fs.toArray.map(f => math.log(f + lambda) - thetaLogDenom) + i += 1 } - new NaiveBayesModel(pi, theta) + new NaiveBayesModel(labels, pi, theta) } } @@ -158,8 +163,9 @@ object NaiveBayes { } else { NaiveBayes.train(data, args(2).toDouble) } - println("Pi: " + model.pi.mkString("[", ", ", "]")) - println("Theta:\n" + model.theta.map(_.mkString("[", ", ", "]")).mkString("[", "\n ", "]")) + + println("Pi\n: " + model.pi) + println("Theta:\n" + model.theta) sc.stop() } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 6dff29dfb45cc..e31a08899f8bc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -18,13 +18,11 @@ package org.apache.spark.mllib.classification import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.util.DataValidators - -import org.jblas.DoubleMatrix +import org.apache.spark.mllib.util.{DataValidators, MLUtils} +import org.apache.spark.rdd.RDD /** * Model for Support Vector Machines (SVMs). @@ -33,15 +31,37 @@ import org.jblas.DoubleMatrix * @param intercept Intercept computed for this model. */ class SVMModel( - override val weights: Array[Double], + override val weights: Vector, override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) - with ClassificationModel with Serializable { + extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { + + private var threshold: Option[Double] = Some(0.0) + + /** + * Sets the threshold that separates positive predictions from negative predictions. An example + * with prediction score greater than or equal to this threshold is identified as an positive, + * and negative otherwise. The default value is 0.0. + */ + def setThreshold(threshold: Double): this.type = { + this.threshold = Some(threshold) + this + } - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, + /** + * Clears the threshold so that `predict` will output raw prediction scores. + */ + def clearThreshold(): this.type = { + threshold = None + this + } + + override def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double) = { - val margin = dataMatrix.dot(weightMatrix) + intercept - if (margin < 0) 0.0 else 1.0 + val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept + threshold match { + case Some(t) => if (margin < 0) 0.0 else 1.0 + case None => margin + } } } @@ -71,7 +91,7 @@ class SVMWithSGD private ( */ def this() = this(1.0, 100, 1.0, 1.0) - def createModel(weights: Array[Double], intercept: Double) = { + def createModel(weights: Vector, intercept: Double) = { new SVMModel(weights, intercept) } } @@ -103,11 +123,9 @@ object SVMWithSGD { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeights: Array[Double]) - : SVMModel = - { - new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input, - initialWeights) + initialWeights: Vector): SVMModel = { + new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction) + .run(input, initialWeights) } /** @@ -127,9 +145,7 @@ object SVMWithSGD { numIterations: Int, stepSize: Double, regParam: Double, - miniBatchFraction: Double) - : SVMModel = - { + miniBatchFraction: Double): SVMModel = { new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input) } @@ -149,9 +165,7 @@ object SVMWithSGD { input: RDD[LabeledPoint], numIterations: Int, stepSize: Double, - regParam: Double) - : SVMModel = - { + regParam: Double): SVMModel = { train(input, numIterations, stepSize, regParam, 1.0) } @@ -165,11 +179,7 @@ object SVMWithSGD { * @param numIterations Number of iterations of gradient descent to run. * @return a SVMModel which has the weights and offset from training. */ - def train( - input: RDD[LabeledPoint], - numIterations: Int) - : SVMModel = - { + def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = { train(input, numIterations, 1.0, 1.0, 1.0) } @@ -181,7 +191,8 @@ object SVMWithSGD { val sc = new SparkContext(args(0), "SVM") val data = MLUtils.loadLabeledData(sc, args(1)) val model = SVMWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble) - println("Weights: " + model.weights.mkString("[", ", ", "]")) + + println("Weights: " + model.weights) println("Intercept: " + model.intercept) sc.stop() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index b412738e3f00a..a78503df3134d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -42,8 +42,7 @@ class KMeans private ( var runs: Int, var initializationMode: String, var initializationSteps: Int, - var epsilon: Double) - extends Serializable with Logging { + var epsilon: Double) extends Serializable with Logging { def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4) /** Set the number of clusters to create (k). Default: 2. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 01c1501548f87..2cea58cd3fd22 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -54,6 +54,12 @@ trait Vector extends Serializable { * Converts the instance to a breeze vector. */ private[mllib] def toBreeze: BV[Double] + + /** + * Gets the value of the ith element. + * @param i index + */ + private[mllib] def apply(i: Int): Double = toBreeze(i) } /** @@ -145,6 +151,8 @@ class DenseVector(val values: Array[Double]) extends Vector { override def toArray: Array[Double] = values private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values) + + override def apply(i: Int) = values(i) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 82124703da6cd..20654284965ed 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -17,7 +17,9 @@ package org.apache.spark.mllib.optimization -import org.jblas.DoubleMatrix +import breeze.linalg.{axpy => brzAxpy} + +import org.apache.spark.mllib.linalg.{Vectors, Vector} /** * Class used to compute the gradient for a loss function, given a single data point. @@ -26,17 +28,26 @@ abstract class Gradient extends Serializable { /** * Compute the gradient and loss given the features of a single data point. * - * @param data - Feature values for one data point. Column matrix of size dx1 - * where d is the number of features. - * @param label - Label for this data item. - * @param weights - Column matrix containing weights for every feature. + * @param data features for one data point + * @param label label for this data point + * @param weights weights/coefficients corresponding to features * - * @return A tuple of 2 elements. The first element is a column matrix containing the computed - * gradient and the second element is the loss computed at this data point. + * @return (gradient: Vector, loss: Double) + */ + def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) + + /** + * Compute the gradient and loss given the features of a single data point, + * add the gradient to a provided vector to avoid creating new objects, and return loss. * + * @param data features for one data point + * @param label label for this data point + * @param weights weights/coefficients corresponding to features + * @param cumGradient the computed gradient will be added to this vector + * + * @return loss */ - def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): - (DoubleMatrix, Double) + def compute(data: Vector, label: Double, weights: Vector, cumGradient: Vector): Double } /** @@ -44,12 +55,12 @@ abstract class Gradient extends Serializable { * See also the documentation for the precise formulation. */ class LogisticGradient extends Gradient { - override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): - (DoubleMatrix, Double) = { - val margin: Double = -1.0 * data.dot(weights) + override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { + val brzData = data.toBreeze + val brzWeights = weights.toBreeze + val margin: Double = -1.0 * brzWeights.dot(brzData) val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label - - val gradient = data.mul(gradientMultiplier) + val gradient = brzData * gradientMultiplier val loss = if (label > 0) { math.log(1 + math.exp(margin)) @@ -57,7 +68,26 @@ class LogisticGradient extends Gradient { math.log(1 + math.exp(margin)) - margin } - (gradient, loss) + (Vectors.fromBreeze(gradient), loss) + } + + override def compute( + data: Vector, + label: Double, + weights: Vector, + cumGradient: Vector): Double = { + val brzData = data.toBreeze + val brzWeights = weights.toBreeze + val margin: Double = -1.0 * brzWeights.dot(brzData) + val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label + + brzAxpy(gradientMultiplier, brzData, cumGradient.toBreeze) + + if (label > 0) { + math.log(1 + math.exp(margin)) + } else { + math.log(1 + math.exp(margin)) - margin + } } } @@ -68,14 +98,28 @@ class LogisticGradient extends Gradient { * See also the documentation for the precise formulation. */ class LeastSquaresGradient extends Gradient { - override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): - (DoubleMatrix, Double) = { - val diff: Double = data.dot(weights) - label - + override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { + val brzData = data.toBreeze + val brzWeights = weights.toBreeze + val diff = brzWeights.dot(brzData) - label val loss = diff * diff - val gradient = data.mul(2.0 * diff) + val gradient = brzData * (2.0 * diff) - (gradient, loss) + (Vectors.fromBreeze(gradient), loss) + } + + override def compute( + data: Vector, + label: Double, + weights: Vector, + cumGradient: Vector): Double = { + val brzData = data.toBreeze + val brzWeights = weights.toBreeze + val diff = brzWeights.dot(brzData) - label + + brzAxpy(2.0 * diff, brzData, cumGradient.toBreeze) + + diff * diff } } @@ -85,19 +129,40 @@ class LeastSquaresGradient extends Gradient { * NOTE: This assumes that the labels are {0,1} */ class HingeGradient extends Gradient { - override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): - (DoubleMatrix, Double) = { + override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { + val brzData = data.toBreeze + val brzWeights = weights.toBreeze + val dotProduct = brzWeights.dot(brzData) + + // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x))) + // Therefore the gradient is -(2y - 1)*x + val labelScaled = 2 * label - 1.0 + + if (1.0 > labelScaled * dotProduct) { + (Vectors.fromBreeze(brzData * (-labelScaled)), 1.0 - labelScaled * dotProduct) + } else { + (Vectors.dense(new Array[Double](weights.size)), 0.0) + } + } - val dotProduct = data.dot(weights) + override def compute( + data: Vector, + label: Double, + weights: Vector, + cumGradient: Vector): Double = { + val brzData = data.toBreeze + val brzWeights = weights.toBreeze + val dotProduct = brzWeights.dot(brzData) // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x))) // Therefore the gradient is -(2y - 1)*x val labelScaled = 2 * label - 1.0 if (1.0 > labelScaled * dotProduct) { - (data.mul(-labelScaled), 1.0 - labelScaled * dotProduct) + brzAxpy(-labelScaled, brzData, cumGradient.toBreeze) + 1.0 - labelScaled * dotProduct } else { - (DoubleMatrix.zeros(1, weights.length), 0.0) + 0.0 } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index b967b22e818d3..d0777ffd63ff8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -17,12 +17,13 @@ package org.apache.spark.mllib.optimization -import org.apache.spark.Logging -import org.apache.spark.rdd.RDD +import scala.collection.mutable.ArrayBuffer -import org.jblas.DoubleMatrix +import breeze.linalg.{Vector => BV, DenseVector => BDV} -import scala.collection.mutable.ArrayBuffer +import org.apache.spark.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.{Vectors, Vector} /** * Class used to solve an optimization problem using Gradient Descent. @@ -91,18 +92,16 @@ class GradientDescent(var gradient: Gradient, var updater: Updater) this } - def optimize(data: RDD[(Double, Array[Double])], initialWeights: Array[Double]) - : Array[Double] = { - - val (weights, stochasticLossHistory) = GradientDescent.runMiniBatchSGD( - data, - gradient, - updater, - stepSize, - numIterations, - regParam, - miniBatchFraction, - initialWeights) + def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = { + val (weights, _) = GradientDescent.runMiniBatchSGD( + data, + gradient, + updater, + stepSize, + numIterations, + regParam, + miniBatchFraction, + initialWeights) weights } @@ -133,14 +132,14 @@ object GradientDescent extends Logging { * stochastic loss computed for every iteration. */ def runMiniBatchSGD( - data: RDD[(Double, Array[Double])], + data: RDD[(Double, Vector)], gradient: Gradient, updater: Updater, stepSize: Double, numIterations: Int, regParam: Double, miniBatchFraction: Double, - initialWeights: Array[Double]) : (Array[Double], Array[Double]) = { + initialWeights: Vector): (Vector, Array[Double]) = { val stochasticLossHistory = new ArrayBuffer[Double](numIterations) @@ -148,24 +147,27 @@ object GradientDescent extends Logging { val miniBatchSize = nexamples * miniBatchFraction // Initialize weights as a column vector - var weights = new DoubleMatrix(initialWeights.length, 1, initialWeights:_*) + var weights = Vectors.dense(initialWeights.toArray) /** * For the first iteration, the regVal will be initialized as sum of sqrt of * weights if it's L2 update; for L1 update; the same logic is followed. */ var regVal = updater.compute( - weights, new DoubleMatrix(initialWeights.length, 1), 0, 1, regParam)._2 + weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2 for (i <- 1 to numIterations) { // Sample a subset (fraction miniBatchFraction) of the total data // compute and sum up the subgradients on this subset (this is one map-reduce) - val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i).map { - case (y, features) => - val featuresCol = new DoubleMatrix(features.length, 1, features:_*) - val (grad, loss) = gradient.compute(featuresCol, y, weights) - (grad, loss) - }.reduce((a, b) => (a._1.addi(b._1), a._2 + b._2)) + val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i) + .aggregate((BDV.zeros[Double](weights.size), 0.0))( + seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) => + val l = gradient.compute(features, label, weights, Vectors.fromBreeze(grad)) + (grad, loss + l) + }, + combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) => + (grad1 += grad2, loss1 + loss2) + }) /** * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration @@ -173,7 +175,7 @@ object GradientDescent extends Logging { */ stochasticLossHistory.append(lossSum / miniBatchSize + regVal) val update = updater.compute( - weights, gradientSum.div(miniBatchSize), stepSize, i, regParam) + weights, Vectors.fromBreeze(gradientSum / miniBatchSize), stepSize, i, regParam) weights = update._1 regVal = update._2 } @@ -181,6 +183,6 @@ object GradientDescent extends Logging { logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format( stochasticLossHistory.takeRight(10).mkString(", "))) - (weights.toArray, stochasticLossHistory.toArray) + (weights, stochasticLossHistory.toArray) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala index 94d30b56f212b..f9ce908a5f3b0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala @@ -19,11 +19,12 @@ package org.apache.spark.mllib.optimization import org.apache.spark.rdd.RDD -trait Optimizer { +import org.apache.spark.mllib.linalg.Vector + +trait Optimizer extends Serializable { /** * Solve the provided convex optimization problem. */ - def optimize(data: RDD[(Double, Array[Double])], initialWeights: Array[Double]): Array[Double] - + def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala index bf8f731459e99..3b7754cd7ac28 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala @@ -18,7 +18,10 @@ package org.apache.spark.mllib.optimization import scala.math._ -import org.jblas.DoubleMatrix + +import breeze.linalg.{norm => brzNorm, axpy => brzAxpy, Vector => BV} + +import org.apache.spark.mllib.linalg.{Vectors, Vector} /** * Class used to perform steps (weight update) using Gradient Descent methods. @@ -47,8 +50,12 @@ abstract class Updater extends Serializable { * @return A tuple of 2 elements. The first element is a column matrix containing updated weights, * and the second element is the regularization value computed using updated weights. */ - def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, stepSize: Double, iter: Int, - regParam: Double): (DoubleMatrix, Double) + def compute( + weightsOld: Vector, + gradient: Vector, + stepSize: Double, + iter: Int, + regParam: Double): (Vector, Double) } /** @@ -56,11 +63,17 @@ abstract class Updater extends Serializable { * Uses a step-size decreasing with the square root of the number of iterations. */ class SimpleUpdater extends Updater { - override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, - stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = { + override def compute( + weightsOld: Vector, + gradient: Vector, + stepSize: Double, + iter: Int, + regParam: Double): (Vector, Double) = { val thisIterStepSize = stepSize / math.sqrt(iter) - val step = gradient.mul(thisIterStepSize) - (weightsOld.sub(step), 0) + val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector + brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights) + + (Vectors.fromBreeze(brzWeights), 0) } } @@ -83,19 +96,26 @@ class SimpleUpdater extends Updater { * Equivalently, set weight component to signum(w) * max(0.0, abs(w) - shrinkageVal) */ class L1Updater extends Updater { - override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, - stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = { + override def compute( + weightsOld: Vector, + gradient: Vector, + stepSize: Double, + iter: Int, + regParam: Double): (Vector, Double) = { val thisIterStepSize = stepSize / math.sqrt(iter) - val step = gradient.mul(thisIterStepSize) // Take gradient step - val newWeights = weightsOld.sub(step) + val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector + brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights) // Apply proximal operator (soft thresholding) val shrinkageVal = regParam * thisIterStepSize - (0 until newWeights.length).foreach { i => - val wi = newWeights.get(i) - newWeights.put(i, signum(wi) * max(0.0, abs(wi) - shrinkageVal)) + var i = 0 + while (i < brzWeights.length) { + val wi = brzWeights(i) + brzWeights(i) = signum(wi) * max(0.0, abs(wi) - shrinkageVal) + i += 1 } - (newWeights, newWeights.norm1 * regParam) + + (Vectors.fromBreeze(brzWeights), brzNorm(brzWeights, 1.0) * regParam) } } @@ -105,16 +125,23 @@ class L1Updater extends Updater { * Uses a step-size decreasing with the square root of the number of iterations. */ class SquaredL2Updater extends Updater { - override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, - stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = { - val thisIterStepSize = stepSize / math.sqrt(iter) - val step = gradient.mul(thisIterStepSize) + override def compute( + weightsOld: Vector, + gradient: Vector, + stepSize: Double, + iter: Int, + regParam: Double): (Vector, Double) = { // add up both updates from the gradient of the loss (= step) as well as // the gradient of the regularizer (= regParam * weightsOld) // w' = w - thisIterStepSize * (gradient + regParam * w) // w' = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient - val newWeights = weightsOld.mul(1.0 - thisIterStepSize * regParam).sub(step) - (newWeights, 0.5 * pow(newWeights.norm2, 2.0) * regParam) + val thisIterStepSize = stepSize / math.sqrt(iter) + val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector + brzWeights :*= (1.0 - thisIterStepSize * regParam) + brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights) + val norm = brzNorm(brzWeights, 2.0) + + (Vectors.fromBreeze(brzWeights), 0.5 * regParam * norm * norm) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 3e1ed91bf6729..80dc0f12ff84f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -17,11 +17,12 @@ package org.apache.spark.mllib.regression +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} + import org.apache.spark.{Logging, SparkException} import org.apache.spark.rdd.RDD import org.apache.spark.mllib.optimization._ - -import org.jblas.DoubleMatrix +import org.apache.spark.mllib.linalg.{Vectors, Vector} /** * GeneralizedLinearModel (GLM) represents a model trained using @@ -31,12 +32,9 @@ import org.jblas.DoubleMatrix * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. */ -abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept: Double) +abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double) extends Serializable { - // Create a column vector that can be used for predictions - private val weightsMatrix = new DoubleMatrix(weights.length, 1, weights:_*) - /** * Predict the result given a data point and the weights learned. * @@ -44,8 +42,7 @@ abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept: * @param weightMatrix Column vector containing the weights of the model * @param intercept Intercept of the model. */ - def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double): Double + protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double /** * Predict values for the given data set using the model trained. @@ -53,16 +50,13 @@ abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept: * @param testData RDD representing data points to be predicted * @return RDD[Double] where each entry contains the corresponding prediction */ - def predict(testData: RDD[Array[Double]]): RDD[Double] = { + def predict(testData: RDD[Vector]): RDD[Double] = { // A small optimization to avoid serializing the entire model. Only the weightsMatrix // and intercept is needed. - val localWeights = weightsMatrix + val localWeights = weights val localIntercept = intercept - testData.map { x => - val dataMatrix = new DoubleMatrix(1, x.length, x:_*) - predictPoint(dataMatrix, localWeights, localIntercept) - } + testData.map(v => predictPoint(v, localWeights, localIntercept)) } /** @@ -71,14 +65,13 @@ abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept: * @param testData array representing a single data point * @return Double prediction from the trained model */ - def predict(testData: Array[Double]): Double = { - val dataMat = new DoubleMatrix(1, testData.length, testData:_*) - predictPoint(dataMat, weightsMatrix, intercept) + def predict(testData: Vector): Double = { + predictPoint(testData, weights, intercept) } } /** - * GeneralizedLinearAlgorithm implements methods to train a Genearalized Linear Model (GLM). + * GeneralizedLinearAlgorithm implements methods to train a Generalized Linear Model (GLM). * This class should be extended with an Optimizer to create a new GLM. */ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] @@ -88,6 +81,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] val optimizer: Optimizer + /** Whether to add intercept (default: true). */ protected var addIntercept: Boolean = true protected var validateData: Boolean = true @@ -95,7 +89,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] /** * Create a model given the weights and intercept */ - protected def createModel(weights: Array[Double], intercept: Double): M + protected def createModel(weights: Vector, intercept: Double): M /** * Set if the algorithm should add an intercept. Default true. @@ -117,17 +111,27 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * Run the algorithm with the configured parameters on an input * RDD of LabeledPoint entries. */ - def run(input: RDD[LabeledPoint]) : M = { - val nfeatures: Int = input.first().features.length - val initialWeights = new Array[Double](nfeatures) + def run(input: RDD[LabeledPoint]): M = { + val numFeatures: Int = input.first().features.size + val initialWeights = Vectors.dense(new Array[Double](numFeatures)) run(input, initialWeights) } + /** Prepends one to the input vector. */ + private def prependOne(vector: Vector): Vector = { + val vector1 = vector.toBreeze match { + case dv: BDV[Double] => BDV.vertcat(BDV.ones[Double](1), dv) + case sv: BSV[Double] => BSV.vertcat(new BSV[Double](Array(0), Array(1.0), 1), sv) + case v: Any => throw new IllegalArgumentException("Do not support vector type " + v.getClass) + } + Vectors.fromBreeze(vector1) + } + /** * Run the algorithm with the configured parameters on an input RDD * of LabeledPoint entries starting from the initial weights provided. */ - def run(input: RDD[LabeledPoint], initialWeights: Array[Double]) : M = { + def run(input: RDD[LabeledPoint], initialWeights: Vector): M = { // Check the data properties before running the optimizer if (validateData && !validators.forall(func => func(input))) { @@ -136,27 +140,26 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] // Prepend an extra variable consisting of all 1.0's for the intercept. val data = if (addIntercept) { - input.map(labeledPoint => (labeledPoint.label, 1.0 +: labeledPoint.features)) + input.map(labeledPoint => (labeledPoint.label, prependOne(labeledPoint.features))) } else { input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) } val initialWeightsWithIntercept = if (addIntercept) { - 0.0 +: initialWeights + prependOne(initialWeights) } else { initialWeights } val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept) - val (intercept, weights) = if (addIntercept) { - (weightsWithIntercept(0), weightsWithIntercept.tail) - } else { - (0.0, weightsWithIntercept) - } - - logInfo("Final weights " + weights.mkString(",")) - logInfo("Final intercept " + intercept) + val intercept = if (addIntercept) weightsWithIntercept(0) else 0.0 + val weights = + if (addIntercept) { + Vectors.dense(weightsWithIntercept.toArray.slice(1, weightsWithIntercept.size)) + } else { + weightsWithIntercept + } createModel(weights, intercept) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index 1a18292fe3f3b..3deab1ab785b9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -17,14 +17,16 @@ package org.apache.spark.mllib.regression +import org.apache.spark.mllib.linalg.Vector + /** * Class that represents the features and labels of a data point. * * @param label Label for this data point. * @param features List of features for this data point. */ -case class LabeledPoint(label: Double, features: Array[Double]) { +case class LabeledPoint(label: Double, features: Vector) { override def toString: String = { - "LabeledPoint(%s, %s)".format(label, features.mkString("[", ", ", "]")) + "LabeledPoint(%s, %s)".format(label, features) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index be63ce8538fef..25920d0dc976e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.regression -import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.util.MLUtils - -import org.jblas.DoubleMatrix +import org.apache.spark.rdd.RDD /** * Regression model trained using Lasso. @@ -31,16 +30,16 @@ import org.jblas.DoubleMatrix * @param intercept Intercept computed for this model. */ class LassoModel( - override val weights: Array[Double], + override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { - override def predictPoint( - dataMatrix: DoubleMatrix, - weightMatrix: DoubleMatrix, + override protected def predictPoint( + dataMatrix: Vector, + weightMatrix: Vector, intercept: Double): Double = { - dataMatrix.dot(weightMatrix) + intercept + weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } } @@ -57,8 +56,7 @@ class LassoWithSGD private ( var numIterations: Int, var regParam: Double, var miniBatchFraction: Double) - extends GeneralizedLinearAlgorithm[LassoModel] - with Serializable { + extends GeneralizedLinearAlgorithm[LassoModel] with Serializable { val gradient = new LeastSquaresGradient() val updater = new L1Updater() @@ -70,10 +68,6 @@ class LassoWithSGD private ( // We don't want to penalize the intercept, so set this to false. super.setIntercept(false) - var yMean = 0.0 - var xColMean: DoubleMatrix = _ - var xColSd: DoubleMatrix = _ - /** * Construct a Lasso object with default parameters */ @@ -85,36 +79,8 @@ class LassoWithSGD private ( this } - override def createModel(weights: Array[Double], intercept: Double) = { - val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*) - val weightsScaled = weightsMat.div(xColSd) - val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0) - - new LassoModel(weightsScaled.data, interceptScaled) - } - - override def run( - input: RDD[LabeledPoint], - initialWeights: Array[Double]) - : LassoModel = - { - val nfeatures: Int = input.first.features.length - val nexamples: Long = input.count() - - // To avoid penalizing the intercept, we center and scale the data. - val stats = MLUtils.computeStats(input, nfeatures, nexamples) - yMean = stats._1 - xColMean = stats._2 - xColSd = stats._3 - - val normalizedData = input.map { point => - val yNormalized = point.label - yMean - val featuresMat = new DoubleMatrix(nfeatures, 1, point.features:_*) - val featuresNormalized = featuresMat.sub(xColMean).divi(xColSd) - LabeledPoint(yNormalized, featuresNormalized.toArray) - } - - super.run(normalizedData, initialWeights) + override protected def createModel(weights: Vector, intercept: Double) = { + new LassoModel(weights, intercept) } } @@ -144,11 +110,9 @@ object LassoWithSGD { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeights: Array[Double]) - : LassoModel = - { - new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input, - initialWeights) + initialWeights: Vector): LassoModel = { + new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction) + .run(input, initialWeights) } /** @@ -168,9 +132,7 @@ object LassoWithSGD { numIterations: Int, stepSize: Double, regParam: Double, - miniBatchFraction: Double) - : LassoModel = - { + miniBatchFraction: Double): LassoModel = { new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input) } @@ -190,9 +152,7 @@ object LassoWithSGD { input: RDD[LabeledPoint], numIterations: Int, stepSize: Double, - regParam: Double) - : LassoModel = - { + regParam: Double): LassoModel = { train(input, numIterations, stepSize, regParam, 1.0) } @@ -208,9 +168,7 @@ object LassoWithSGD { */ def train( input: RDD[LabeledPoint], - numIterations: Int) - : LassoModel = - { + numIterations: Int): LassoModel = { train(input, numIterations, 1.0, 1.0, 1.0) } @@ -222,7 +180,8 @@ object LassoWithSGD { val sc = new SparkContext(args(0), "Lasso") val data = MLUtils.loadLabeledData(sc, args(1)) val model = LassoWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble) - println("Weights: " + model.weights.mkString("[", ", ", "]")) + + println("Weights: " + model.weights) println("Intercept: " + model.intercept) sc.stop() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index f5f15d1a33f4d..9ed927994e795 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -19,11 +19,10 @@ package org.apache.spark.mllib.regression import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.util.MLUtils -import org.jblas.DoubleMatrix - /** * Regression model trained using LinearRegression. * @@ -31,15 +30,15 @@ import org.jblas.DoubleMatrix * @param intercept Intercept computed for this model. */ class LinearRegressionModel( - override val weights: Array[Double], + override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { - override def predictPoint( - dataMatrix: DoubleMatrix, - weightMatrix: DoubleMatrix, + override protected def predictPoint( + dataMatrix: Vector, + weightMatrix: Vector, intercept: Double): Double = { - dataMatrix.dot(weightMatrix) + intercept + weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } } @@ -69,7 +68,7 @@ class LinearRegressionWithSGD private ( */ def this() = this(1.0, 100, 1.0) - override def createModel(weights: Array[Double], intercept: Double) = { + override protected def createModel(weights: Vector, intercept: Double) = { new LinearRegressionModel(weights, intercept) } } @@ -98,11 +97,9 @@ object LinearRegressionWithSGD { numIterations: Int, stepSize: Double, miniBatchFraction: Double, - initialWeights: Array[Double]) - : LinearRegressionModel = - { - new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction).run(input, - initialWeights) + initialWeights: Vector): LinearRegressionModel = { + new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction) + .run(input, initialWeights) } /** @@ -120,9 +117,7 @@ object LinearRegressionWithSGD { input: RDD[LabeledPoint], numIterations: Int, stepSize: Double, - miniBatchFraction: Double) - : LinearRegressionModel = - { + miniBatchFraction: Double): LinearRegressionModel = { new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction).run(input) } @@ -140,9 +135,7 @@ object LinearRegressionWithSGD { def train( input: RDD[LabeledPoint], numIterations: Int, - stepSize: Double) - : LinearRegressionModel = - { + stepSize: Double): LinearRegressionModel = { train(input, numIterations, stepSize, 1.0) } @@ -158,9 +151,7 @@ object LinearRegressionWithSGD { */ def train( input: RDD[LabeledPoint], - numIterations: Int) - : LinearRegressionModel = - { + numIterations: Int): LinearRegressionModel = { train(input, numIterations, 1.0, 1.0) } @@ -172,7 +163,7 @@ object LinearRegressionWithSGD { val sc = new SparkContext(args(0), "LinearRegression") val data = MLUtils.loadLabeledData(sc, args(1)) val model = LinearRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble) - println("Weights: " + model.weights.mkString("[", ", ", "]")) + println("Weights: " + model.weights) println("Intercept: " + model.intercept) sc.stop() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala index 423afc32d665c..5e4b8a345b1c5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.Vector trait RegressionModel extends Serializable { /** @@ -26,7 +27,7 @@ trait RegressionModel extends Serializable { * @param testData RDD representing data points to be predicted * @return RDD[Double] where each entry contains the corresponding prediction */ - def predict(testData: RDD[Array[Double]]): RDD[Double] + def predict(testData: RDD[Vector]): RDD[Double] /** * Predict values for a single data point using the model trained. @@ -34,5 +35,5 @@ trait RegressionModel extends Serializable { * @param testData array representing a single data point * @return Double prediction from the trained model */ - def predict(testData: Array[Double]): Double + def predict(testData: Vector): Double } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index feb100f21888f..1f17d2107f940 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -21,8 +21,7 @@ import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.util.MLUtils - -import org.jblas.DoubleMatrix +import org.apache.spark.mllib.linalg.Vector /** * Regression model trained using RidgeRegression. @@ -31,16 +30,16 @@ import org.jblas.DoubleMatrix * @param intercept Intercept computed for this model. */ class RidgeRegressionModel( - override val weights: Array[Double], + override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { - override def predictPoint( - dataMatrix: DoubleMatrix, - weightMatrix: DoubleMatrix, + override protected def predictPoint( + dataMatrix: Vector, + weightMatrix: Vector, intercept: Double): Double = { - dataMatrix.dot(weightMatrix) + intercept + weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } } @@ -57,8 +56,7 @@ class RidgeRegressionWithSGD private ( var numIterations: Int, var regParam: Double, var miniBatchFraction: Double) - extends GeneralizedLinearAlgorithm[RidgeRegressionModel] - with Serializable { + extends GeneralizedLinearAlgorithm[RidgeRegressionModel] with Serializable { val gradient = new LeastSquaresGradient() val updater = new SquaredL2Updater() @@ -71,10 +69,6 @@ class RidgeRegressionWithSGD private ( // We don't want to penalize the intercept in RidgeRegression, so set this to false. super.setIntercept(false) - var yMean = 0.0 - var xColMean: DoubleMatrix = _ - var xColSd: DoubleMatrix = _ - /** * Construct a RidgeRegression object with default parameters */ @@ -86,36 +80,8 @@ class RidgeRegressionWithSGD private ( this } - override def createModel(weights: Array[Double], intercept: Double) = { - val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*) - val weightsScaled = weightsMat.div(xColSd) - val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0) - - new RidgeRegressionModel(weightsScaled.data, interceptScaled) - } - - override def run( - input: RDD[LabeledPoint], - initialWeights: Array[Double]) - : RidgeRegressionModel = - { - val nfeatures: Int = input.first().features.length - val nexamples: Long = input.count() - - // To avoid penalizing the intercept, we center and scale the data. - val stats = MLUtils.computeStats(input, nfeatures, nexamples) - yMean = stats._1 - xColMean = stats._2 - xColSd = stats._3 - - val normalizedData = input.map { point => - val yNormalized = point.label - yMean - val featuresMat = new DoubleMatrix(nfeatures, 1, point.features:_*) - val featuresNormalized = featuresMat.sub(xColMean).divi(xColSd) - LabeledPoint(yNormalized, featuresNormalized.toArray) - } - - super.run(normalizedData, initialWeights) + override protected def createModel(weights: Vector, intercept: Double) = { + new RidgeRegressionModel(weights, intercept) } } @@ -144,9 +110,7 @@ object RidgeRegressionWithSGD { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeights: Array[Double]) - : RidgeRegressionModel = - { + initialWeights: Vector): RidgeRegressionModel = { new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run( input, initialWeights) } @@ -167,9 +131,7 @@ object RidgeRegressionWithSGD { numIterations: Int, stepSize: Double, regParam: Double, - miniBatchFraction: Double) - : RidgeRegressionModel = - { + miniBatchFraction: Double): RidgeRegressionModel = { new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input) } @@ -188,9 +150,7 @@ object RidgeRegressionWithSGD { input: RDD[LabeledPoint], numIterations: Int, stepSize: Double, - regParam: Double) - : RidgeRegressionModel = - { + regParam: Double): RidgeRegressionModel = { train(input, numIterations, stepSize, regParam, 1.0) } @@ -205,23 +165,22 @@ object RidgeRegressionWithSGD { */ def train( input: RDD[LabeledPoint], - numIterations: Int) - : RidgeRegressionModel = - { + numIterations: Int): RidgeRegressionModel = { train(input, numIterations, 1.0, 1.0, 1.0) } def main(args: Array[String]) { if (args.length != 5) { - println("Usage: RidgeRegression " + - " ") + println("Usage: RidgeRegression " + + " ") System.exit(1) } val sc = new SparkContext(args(0), "RidgeRegression") val data = MLUtils.loadLabeledData(sc, args(1)) val model = RidgeRegressionWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble) - println("Weights: " + model.weights.mkString("[", ", ", "]")) + + println("Weights: " + model.weights) println("Intercept: " + model.intercept) sc.stop() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala new file mode 100644 index 0000000000000..dee9594a9dd79 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -0,0 +1,1151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import scala.util.control.Breaks._ + +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} +import org.apache.spark.mllib.tree.model._ +import org.apache.spark.rdd.RDD +import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.mllib.linalg.{Vector, Vectors} + +/** + * A class that implements a decision tree algorithm for classification and regression. It + * supports both continuous and categorical features. + * @param strategy The configuration parameters for the tree algorithm which specify the type + * of algorithm (classification, regression, etc.), feature type (continuous, + * categorical), depth of the tree, quantile calculation strategy, etc. + */ +class DecisionTree private(val strategy: Strategy) extends Serializable with Logging { + + /** + * Method to train a decision tree model over an RDD + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * @return a DecisionTreeModel that can be used for prediction + */ + def train(input: RDD[LabeledPoint]): DecisionTreeModel = { + + // Cache input RDD for speedup during multiple passes. + input.cache() + logDebug("algo = " + strategy.algo) + + // Find the splits and the corresponding bins (interval between the splits) using a sample + // of the input data. + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + logDebug("numSplits = " + bins(0).length) + + // depth of the decision tree + val maxDepth = strategy.maxDepth + // the max number of nodes possible given the depth of the tree + val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1 + // Initialize an array to hold filters applied to points for each node. + val filters = new Array[List[Filter]](maxNumNodes) + // The filter at the top node is an empty list. + filters(0) = List() + // Initialize an array to hold parent impurity calculations for each node. + val parentImpurities = new Array[Double](maxNumNodes) + // dummy value for top node (updated during first split calculation) + val nodes = new Array[Node](maxNumNodes) + + + /* + * The main idea here is to perform level-wise training of the decision tree nodes thus + * reducing the passes over the data from l to log2(l) where l is the total number of nodes. + * Each data sample is checked for validity w.r.t to each node at a given level -- i.e., + * the sample is only used for the split calculation at the node if the sampled would have + * still survived the filters of the parent nodes. + */ + + // TODO: Convert for loop to while loop + breakable { + for (level <- 0 until maxDepth) { + + logDebug("#####################################") + logDebug("level = " + level) + logDebug("#####################################") + + // Find best split for all nodes at a level. + val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, + level, filters, splits, bins) + + for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { + // Extract info for nodes at the current level. + extractNodeInfo(nodeSplitStats, level, index, nodes) + // Extract info for nodes at the next lower level. + extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, + filters) + logDebug("final best split = " + nodeSplitStats._1) + } + require(scala.math.pow(2, level) == splitsStatsForLevel.length) + // Check whether all the nodes at the current level at leaves. + val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) + logDebug("all leaf = " + allLeaf) + if (allLeaf) break // no more tree construction + } + } + + // Initialize the top or root node of the tree. + val topNode = nodes(0) + // Build the full tree using the node info calculated in the level-wise best split calculations. + topNode.build(nodes) + + new DecisionTreeModel(topNode, strategy.algo) + } + + /** + * Extract the decision tree node information for the given tree level and node index + */ + private def extractNodeInfo( + nodeSplitStats: (Split, InformationGainStats), + level: Int, + index: Int, + nodes: Array[Node]): Unit = { + val split = nodeSplitStats._1 + val stats = nodeSplitStats._2 + val nodeIndex = scala.math.pow(2, level).toInt - 1 + index + val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1) + val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) + logDebug("Node = " + node) + nodes(nodeIndex) = node + } + + /** + * Extract the decision tree node information for the children of the node + */ + private def extractInfoForLowerLevels( + level: Int, + index: Int, + maxDepth: Int, + nodeSplitStats: (Split, InformationGainStats), + parentImpurities: Array[Double], + filters: Array[List[Filter]]): Unit = { + // 0 corresponds to the left child node and 1 corresponds to the right child node. + // TODO: Convert to while loop + for (i <- 0 to 1) { + // Calculate the index of the node from the node level and the index at the current level. + val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i + if (level < maxDepth - 1) { + val impurity = if (i == 0) { + nodeSplitStats._2.leftImpurity + } else { + nodeSplitStats._2.rightImpurity + } + logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity) + // noting the parent impurities + parentImpurities(nodeIndex) = impurity + // noting the parents filters for the child nodes + val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1) + filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2) + for (filter <- filters(nodeIndex)) { + logDebug("Filter = " + filter) + } + } + } + } +} + +object DecisionTree extends Serializable with Logging { + + /** + * Method to train a decision tree model where the instances are represented as an RDD of + * (label, features) pairs. The method supports binary classification and regression. For the + * binary classification, the label for each instance should either be 0 or 1 to denote the two + * classes. The parameters for the algorithm are specified using the strategy parameter. + * + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @param strategy The configuration parameters for the tree algorithm which specify the type + * of algorithm (classification, regression, etc.), feature type (continuous, + * categorical), depth of the tree, quantile calculation strategy, etc. + * @return a DecisionTreeModel that can be used for prediction + */ + def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { + new DecisionTree(strategy).train(input: RDD[LabeledPoint]) + } + + /** + * Method to train a decision tree model where the instances are represented as an RDD of + * (label, features) pairs. The method supports binary classification and regression. For the + * binary classification, the label for each instance should either be 0 or 1 to denote the two + * classes. + * + * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as + * training data + * @param algo algorithm, classification or regression + * @param impurity impurity criterion used for information gain calculation + * @param maxDepth maxDepth maximum depth of the tree + * @return a DecisionTreeModel that can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + algo: Algo, + impurity: Impurity, + maxDepth: Int): DecisionTreeModel = { + val strategy = new Strategy(algo,impurity,maxDepth) + new DecisionTree(strategy).train(input: RDD[LabeledPoint]) + } + + + /** + * Method to train a decision tree model where the instances are represented as an RDD of + * (label, features) pairs. The decision tree method supports binary classification and + * regression. For the binary classification, the label for each instance should either be 0 or + * 1 to denote the two classes. The method also supports categorical features inputs where the + * number of categories can specified using the categoricalFeaturesInfo option. + * + * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as + * training data for DecisionTree + * @param algo classification or regression + * @param impurity criterion used for information gain calculation + * @param maxDepth maximum depth of the tree + * @param maxBins maximum number of bins used for splitting features + * @param quantileCalculationStrategy algorithm for calculating quantiles + * @param categoricalFeaturesInfo A map storing information about the categorical variables and + * the number of discrete values they take. For example, + * an entry (n -> k) implies the feature n is categorical with k + * categories 0, 1, 2, ... , k-1. It's important to note that + * features are zero-indexed. + * @return a DecisionTreeModel that can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + algo: Algo, + impurity: Impurity, + maxDepth: Int, + maxBins: Int, + quantileCalculationStrategy: QuantileStrategy, + categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { + val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy, + categoricalFeaturesInfo) + new DecisionTree(strategy).train(input: RDD[LabeledPoint]) + } + + private val InvalidBinIndex = -1 + + /** + * Returns an array of optimal splits for all nodes at a given level + * + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @param parentImpurities Impurities for all parent nodes for the current level + * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing + * parameters for construction the DecisionTree + * @param level Level of the tree + * @param filters Filters for all nodes at a given level + * @param splits possible splits for all features + * @param bins possible bins for all features + * @return array of splits with best splits for all nodes at a given level. + */ + protected[tree] def findBestSplits( + input: RDD[LabeledPoint], + parentImpurities: Array[Double], + strategy: Strategy, + level: Int, + filters: Array[List[Filter]], + splits: Array[Array[Split]], + bins: Array[Array[Bin]]): Array[(Split, InformationGainStats)] = { + + /* + * The high-level description for the best split optimizations are noted here. + * + * *Level-wise training* + * We perform bin calculations for all nodes at the given level to avoid making multiple + * passes over the data. Thus, for a slightly increased computation and storage cost we save + * several iterations over the data especially at higher levels of the decision tree. + * + * *Bin-wise computation* + * We use a bin-wise best split computation strategy instead of a straightforward best split + * computation strategy. Instead of analyzing each sample for contribution to the left/right + * child node impurity of every split, we first categorize each feature of a sample into a + * bin. Each bin is an interval between a low and high split. Since each splits, and thus bin, + * is ordered (read ordering for categorical variables in the findSplitsBins method), + * we exploit this structure to calculate aggregates for bins and then use these aggregates + * to calculate information gain for each split. + * + * *Aggregation over partitions* + * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know + * the number of splits in advance. Thus, we store the aggregates (at the appropriate + * indices) in a single array for all bins and rely upon the RDD aggregate method to + * drastically reduce the communication overhead. + */ + + // common calculations for multiple nested methods + val numNodes = scala.math.pow(2, level).toInt + logDebug("numNodes = " + numNodes) + // Find the number of features by looking at the first sample. + val numFeatures = input.first().features.size + logDebug("numFeatures = " + numFeatures) + val numBins = bins(0).length + logDebug("numBins = " + numBins) + + /** Find the filters used before reaching the current code. */ + def findParentFilters(nodeIndex: Int): List[Filter] = { + if (level == 0) { + List[Filter]() + } else { + val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex + filters(nodeFilterIndex) + } + } + + /** + * Find whether the sample is valid input for the current node, i.e., whether it passes through + * all the filters for the current node. + */ + def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = { + // leaf + if ((level > 0) & (parentFilters.length == 0)) { + return false + } + + // Apply each filter and check sample validity. Return false when invalid condition found. + for (filter <- parentFilters) { + val features = labeledPoint.features + val featureIndex = filter.split.feature + val threshold = filter.split.threshold + val comparison = filter.comparison + val categories = filter.split.categories + val isFeatureContinuous = filter.split.featureType == Continuous + val feature = features(featureIndex) + if (isFeatureContinuous) { + comparison match { + case -1 => if (feature > threshold) return false + case 1 => if (feature <= threshold) return false + } + } else { + val containsFeature = categories.contains(feature) + comparison match { + case -1 => if (!containsFeature) return false + case 1 => if (containsFeature) return false + } + + } + } + + // Return true when the sample is valid for all filters. + true + } + + /** + * Find bin for one feature. + */ + def findBin( + featureIndex: Int, + labeledPoint: LabeledPoint, + isFeatureContinuous: Boolean): Int = { + val binForFeatures = bins(featureIndex) + val feature = labeledPoint.features(featureIndex) + + /** + * Binary search helper method for continuous feature. + */ + def binarySearchForBins(): Int = { + var left = 0 + var right = binForFeatures.length - 1 + while (left <= right) { + val mid = left + (right - left) / 2 + val bin = binForFeatures(mid) + val lowThreshold = bin.lowSplit.threshold + val highThreshold = bin.highSplit.threshold + if ((lowThreshold < feature) & (highThreshold >= feature)){ + return mid + } + else if (lowThreshold >= feature) { + right = mid - 1 + } + else { + left = mid + 1 + } + } + -1 + } + + /** + * Sequential search helper method to find bin for categorical feature. + */ + def sequentialBinSearchForCategoricalFeature(): Int = { + val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex) + var binIndex = 0 + while (binIndex < numCategoricalBins) { + val bin = bins(featureIndex)(binIndex) + val category = bin.category + val features = labeledPoint.features + if (category == features(featureIndex)) { + return binIndex + } + binIndex += 1 + } + -1 + } + + if (isFeatureContinuous) { + // Perform binary search for finding bin for continuous features. + val binIndex = binarySearchForBins() + if (binIndex == -1){ + throw new UnknownError("no bin was found for continuous variable.") + } + binIndex + } else { + // Perform sequential search to find bin for categorical features. + val binIndex = sequentialBinSearchForCategoricalFeature() + if (binIndex == -1){ + throw new UnknownError("no bin was found for categorical variable.") + } + binIndex + } + } + + /** + * Finds bins for all nodes (and all features) at a given level. + * For l nodes, k features the storage is as follows: + * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk, + * where b_ij is an integer between 0 and numBins - 1. + * Invalid sample is denoted by noting bin for feature 1 as -1. + */ + def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { + // Calculate bin index and label per feature per node. + val arr = new Array[Double](1 + (numFeatures * numNodes)) + arr(0) = labeledPoint.label + var nodeIndex = 0 + while (nodeIndex < numNodes) { + val parentFilters = findParentFilters(nodeIndex) + // Find out whether the sample qualifies for the particular node. + val sampleValid = isSampleValid(parentFilters, labeledPoint) + val shift = 1 + numFeatures * nodeIndex + if (!sampleValid) { + // Mark one bin as -1 is sufficient. + arr(shift) = InvalidBinIndex + } else { + var featureIndex = 0 + while (featureIndex < numFeatures) { + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinuous) + featureIndex += 1 + } + } + nodeIndex += 1 + } + arr + } + + /** + * Performs a sequential aggregation over a partition for classification. For l nodes, + * k features, either the left count or the right count of one of the p bins is + * incremented based upon whether the feature is classified as 0 or 1. + * + * @param agg Array[Double] storing aggregate calculation of size + * 2 * numSplits * numFeatures*numNodes for classification + * @param arr Array[Double] of size 1 + (numFeatures * numNodes) + * @return Array[Double] storing aggregate calculation of size + * 2 * numSplits * numFeatures * numNodes for classification + */ + def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { + // Iterate over all nodes. + var nodeIndex = 0 + while (nodeIndex < numNodes) { + // Check whether the instance was valid for this nodeIndex. + val validSignalIndex = 1 + numFeatures * nodeIndex + val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex + if (isSampleValidForNode) { + // actual class label + val label = arr(0) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // Find the bin index for this feature. + val arrShift = 1 + numFeatures * nodeIndex + val arrIndex = arrShift + featureIndex + // Update the left or right count for one bin. + val aggShift = 2 * numBins * numFeatures * nodeIndex + val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2 + label match { + case 0.0 => agg(aggIndex) = agg(aggIndex) + 1 + case 1.0 => agg(aggIndex + 1) = agg(aggIndex + 1) + 1 + } + featureIndex += 1 + } + } + nodeIndex += 1 + } + } + + /** + * Performs a sequential aggregation over a partition for regression. For l nodes, k features, + * the count, sum, sum of squares of one of the p bins is incremented. + * + * @param agg Array[Double] storing aggregate calculation of size + * 3 * numSplits * numFeatures * numNodes for classification + * @param arr Array[Double] of size 1 + (numFeatures * numNodes) + * @return Array[Double] storing aggregate calculation of size + * 3 * numSplits * numFeatures * numNodes for regression + */ + def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) { + // Iterate over all nodes. + var nodeIndex = 0 + while (nodeIndex < numNodes) { + // Check whether the instance was valid for this nodeIndex. + val validSignalIndex = 1 + numFeatures * nodeIndex + val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex + if (isSampleValidForNode) { + // actual class label + val label = arr(0) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // Find the bin index for this feature. + val arrShift = 1 + numFeatures * nodeIndex + val arrIndex = arrShift + featureIndex + // Update count, sum, and sum^2 for one bin. + val aggShift = 3 * numBins * numFeatures * nodeIndex + val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3 + agg(aggIndex) = agg(aggIndex) + 1 + agg(aggIndex + 1) = agg(aggIndex + 1) + label + agg(aggIndex + 2) = agg(aggIndex + 2) + label*label + featureIndex += 1 + } + } + nodeIndex += 1 + } + } + + /** + * Performs a sequential aggregation over a partition. + */ + def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { + strategy.algo match { + case Classification => classificationBinSeqOp(arr, agg) + case Regression => regressionBinSeqOp(arr, agg) + } + agg + } + + // Calculate bin aggregate length for classification or regression. + val binAggregateLength = strategy.algo match { + case Classification => 2 * numBins * numFeatures * numNodes + case Regression => 3 * numBins * numFeatures * numNodes + } + logDebug("binAggregateLength = " + binAggregateLength) + + /** + * Combines the aggregates from partitions. + * @param agg1 Array containing aggregates from one or more partitions + * @param agg2 Array containing aggregates from one or more partitions + * @return Combined aggregate from agg1 and agg2 + */ + def binCombOp(agg1: Array[Double], agg2: Array[Double]): Array[Double] = { + var index = 0 + val combinedAggregate = new Array[Double](binAggregateLength) + while (index < binAggregateLength) { + combinedAggregate(index) = agg1(index) + agg2(index) + index += 1 + } + combinedAggregate + } + + // Find feature bins for all nodes at a level. + val binMappedRDD = input.map(x => findBinsForLevel(x)) + + // Calculate bin aggregates. + val binAggregates = { + binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) + } + logDebug("binAggregates.length = " + binAggregates.length) + + /** + * Calculates the information gain for all splits based upon left/right split aggregates. + * @param leftNodeAgg left node aggregates + * @param featureIndex feature index + * @param splitIndex split index + * @param rightNodeAgg right node aggregate + * @param topImpurity impurity of the parent node + * @return information gain and statistics for all splits + */ + def calculateGainForSplit( + leftNodeAgg: Array[Array[Double]], + featureIndex: Int, + splitIndex: Int, + rightNodeAgg: Array[Array[Double]], + topImpurity: Double): InformationGainStats = { + strategy.algo match { + case Classification => + val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex) + val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1) + val leftCount = left0Count + left1Count + + val right0Count = rightNodeAgg(featureIndex)(2 * splitIndex) + val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1) + val rightCount = right0Count + right1Count + + val impurity = { + if (level > 0) { + topImpurity + } else { + // Calculate impurity for root node. + strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count) + } + } + + if (leftCount == 0) { + return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,1) + } + if (rightCount == 0) { + return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue,0) + } + + val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) + val rightImpurity = strategy.impurity.calculate(right0Count, right1Count) + + val leftWeight = leftCount.toDouble / (leftCount + rightCount) + val rightWeight = rightCount.toDouble / (leftCount + rightCount) + + val gain = { + if (level > 0) { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } else { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } + } + + val predict = (left1Count + right1Count) / (leftCount + rightCount) + + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) + case Regression => + val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex) + val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1) + val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2) + + val rightCount = rightNodeAgg(featureIndex)(3 * splitIndex) + val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1) + val rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2) + + val impurity = { + if (level > 0) { + topImpurity + } else { + // Calculate impurity for root node. + val count = leftCount + rightCount + val sum = leftSum + rightSum + val sumSquares = leftSumSquares + rightSumSquares + strategy.impurity.calculate(count, sum, sumSquares) + } + } + + if (leftCount == 0) { + return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, + rightSum / rightCount) + } + if (rightCount == 0) { + return new InformationGainStats(0, topImpurity ,topImpurity, + Double.MinValue, leftSum / leftCount) + } + + val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares) + val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares) + + val leftWeight = leftCount.toDouble / (leftCount + rightCount) + val rightWeight = rightCount.toDouble / (leftCount + rightCount) + + val gain = { + if (level > 0) { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } else { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } + } + + val predict = (leftSum + rightSum) / (leftCount + rightCount) + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) + } + } + + /** + * Extracts left and right split aggregates. + * @param binData Array[Double] of size 2*numFeatures*numSplits + * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double], + * Array[Double]) where each array is of size(numFeature,2*(numSplits-1)) + */ + def extractLeftRightNodeAggregates( + binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { + strategy.algo match { + case Classification => + // Initialize left and right split aggregates. + val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) + val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // shift for this featureIndex + val shift = 2 * featureIndex * numBins + + // left node aggregate for the lowest split + leftNodeAgg(featureIndex)(0) = binData(shift + 0) + leftNodeAgg(featureIndex)(1) = binData(shift + 1) + + // right node aggregate for the highest split + rightNodeAgg(featureIndex)(2 * (numBins - 2)) + = binData(shift + (2 * (numBins - 1))) + rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1) + = binData(shift + (2 * (numBins - 1)) + 1) + + // Iterate over all splits. + var splitIndex = 1 + while (splitIndex < numBins - 1) { + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2 * splitIndex) + + leftNodeAgg(featureIndex)(2 * splitIndex - 2) + leftNodeAgg(featureIndex)(2 * splitIndex + 1) = binData(shift + 2 * splitIndex + 1) + + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1) + + // calculating right node aggregate for a split as a sum of right node aggregate of a + // higher split and the right bin aggregate of a bin where the split is a low split + rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) = + binData(shift + (2 *(numBins - 2 - splitIndex))) + + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex)) + rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) = + binData(shift + (2* (numBins - 2 - splitIndex) + 1)) + + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1) + + splitIndex += 1 + } + featureIndex += 1 + } + (leftNodeAgg, rightNodeAgg) + case Regression => + // Initialize left and right split aggregates. + val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) + val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // shift for this featureIndex + val shift = 3 * featureIndex * numBins + // left node aggregate for the lowest split + leftNodeAgg(featureIndex)(0) = binData(shift + 0) + leftNodeAgg(featureIndex)(1) = binData(shift + 1) + leftNodeAgg(featureIndex)(2) = binData(shift + 2) + + // right node aggregate for the highest split + rightNodeAgg(featureIndex)(3 * (numBins - 2)) = + binData(shift + (3 * (numBins - 1))) + rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) = + binData(shift + (3 * (numBins - 1)) + 1) + rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) = + binData(shift + (3 * (numBins - 1)) + 2) + + // Iterate over all splits. + var splitIndex = 1 + while (splitIndex < numBins - 1) { + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + leftNodeAgg(featureIndex)(3 * splitIndex) = binData(shift + 3 * splitIndex) + + leftNodeAgg(featureIndex)(3 * splitIndex - 3) + leftNodeAgg(featureIndex)(3 * splitIndex + 1) = binData(shift + 3 * splitIndex + 1) + + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1) + leftNodeAgg(featureIndex)(3 * splitIndex + 2) = binData(shift + 3 * splitIndex + 2) + + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2) + + // calculating right node aggregate for a split as a sum of right node aggregate of a + // higher split and the right bin aggregate of a bin where the split is a low split + rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) = + binData(shift + (3 * (numBins - 2 - splitIndex))) + + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex)) + rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) = + binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) + + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1) + rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) = + binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) + + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2) + + splitIndex += 1 + } + featureIndex += 1 + } + (leftNodeAgg, rightNodeAgg) + } + } + + /** + * Calculates information gain for all nodes splits. + */ + def calculateGainsForAllNodeSplits( + leftNodeAgg: Array[Array[Double]], + rightNodeAgg: Array[Array[Double]], + nodeImpurity: Double): Array[Array[InformationGainStats]] = { + val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) + + for (featureIndex <- 0 until numFeatures) { + for (splitIndex <- 0 until numBins - 1) { + gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex, + splitIndex, rightNodeAgg, nodeImpurity) + } + } + gains + } + + /** + * Find the best split for a node. + * @param binData Array[Double] of size 2 * numSplits * numFeatures + * @param nodeImpurity impurity of the top node + * @return tuple of split and information gain + */ + def binsToBestSplit( + binData: Array[Double], + nodeImpurity: Double): (Split, InformationGainStats) = { + + logDebug("node impurity = " + nodeImpurity) + + // Extract left right node aggregates. + val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) + + // Calculate gains for all splits. + val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) + + val (bestFeatureIndex,bestSplitIndex, gainStats) = { + // Initialize with infeasible values. + var bestFeatureIndex = Int.MinValue + var bestSplitIndex = Int.MinValue + var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0) + // Iterate over features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // Iterate over all splits. + var splitIndex = 0 + while (splitIndex < numBins - 1) { + val gainStats = gains(featureIndex)(splitIndex) + if (gainStats.gain > bestGainStats.gain) { + bestGainStats = gainStats + bestFeatureIndex = featureIndex + bestSplitIndex = splitIndex + } + splitIndex += 1 + } + featureIndex += 1 + } + (bestFeatureIndex, bestSplitIndex, bestGainStats) + } + + logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex)) + logDebug("best split bin = " + splits(bestFeatureIndex)(bestSplitIndex)) + + (splits(bestFeatureIndex)(bestSplitIndex), gainStats) + } + + /** + * Get bin data for one node. + */ + def getBinDataForNode(node: Int): Array[Double] = { + strategy.algo match { + case Classification => + val shift = 2 * node * numBins * numFeatures + val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures) + binsForNode + case Regression => + val shift = 3 * node * numBins * numFeatures + val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures) + binsForNode + } + } + + // Calculate best splits for all nodes at a given level + val bestSplits = new Array[(Split, InformationGainStats)](numNodes) + // Iterating over all nodes at this level + var node = 0 + while (node < numNodes) { + val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node + val binsForNode: Array[Double] = getBinDataForNode(node) + logDebug("nodeImpurityIndex = " + nodeImpurityIndex) + val parentNodeImpurity = parentImpurities(nodeImpurityIndex) + logDebug("node impurity = " + parentNodeImpurity) + bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) + node += 1 + } + + bestSplits + } + + /** + * Returns split and bins for decision tree calculation. + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing + * parameters for construction the DecisionTree + * @return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree + * .model.Split] of size (numFeatures, numSplits-1) and bins is an Array of [org.apache + * .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1) + */ + protected[tree] def findSplitsBins( + input: RDD[LabeledPoint], + strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { + val count = input.count() + + // Find the number of features by looking at the first sample + val numFeatures = input.take(1)(0).features.size + + val maxBins = strategy.maxBins + val numBins = if (maxBins <= count) maxBins else count.toInt + logDebug("numBins = " + numBins) + + /* + * TODO: Add a require statement ensuring #bins is always greater than the categories. + * It's a limitation of the current implementation but a reasonable trade-off since features + * with large number of categories get favored over continuous features. + */ + if (strategy.categoricalFeaturesInfo.size > 0) { + val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2 + require(numBins >= maxCategoriesForFeatures) + } + + // Calculate the number of sample for approximate quantile calculation. + val requiredSamples = numBins*numBins + val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0 + logDebug("fraction of data used for calculating quantiles = " + fraction) + + // sampled input for RDD calculation + val sampledInput = input.sample(false, fraction, new XORShiftRandom().nextInt()).collect() + val numSamples = sampledInput.length + + val stride: Double = numSamples.toDouble / numBins + logDebug("stride = " + stride) + + strategy.quantileCalculationStrategy match { + case Sort => + val splits = Array.ofDim[Split](numFeatures, numBins - 1) + val bins = Array.ofDim[Bin](numFeatures, numBins) + + // Find all splits. + + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures){ + // Check whether the feature is continuous. + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { + val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted + val stride: Double = numSamples.toDouble / numBins + logDebug("stride = " + stride) + for (index <- 0 until numBins - 1) { + val sampleIndex = (index + 1) * stride.toInt + val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List()) + splits(featureIndex)(index) = split + } + } else { + val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex) + require(maxFeatureValue < numBins, "number of categories should be less than number " + + "of bins") + + // For categorical variables, each bin is a category. The bins are sorted and they + // are ordered by calculating the centroid of their corresponding labels. + val centroidForCategories = + sampledInput.map(lp => (lp.features(featureIndex),lp.label)) + .groupBy(_._1) + .mapValues(x => x.map(_._2).sum / x.map(_._1).length) + + // Check for missing categorical variables and putting them last in the sorted list. + val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]() + for (i <- 0 until maxFeatureValue) { + if (centroidForCategories.contains(i)) { + fullCentroidForCategories(i) = centroidForCategories(i) + } else { + fullCentroidForCategories(i) = Double.MaxValue + } + } + + // bins sorted by centroids + val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2) + + logDebug("centriod for categorical variable = " + categoriesSortedByCentroid) + + var categoriesForSplit = List[Double]() + categoriesSortedByCentroid.iterator.zipWithIndex.foreach { + case ((key, value), index) => + categoriesForSplit = key :: categoriesForSplit + splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical, + categoriesForSplit) + bins(featureIndex)(index) = { + if (index == 0) { + new Bin(new DummyCategoricalSplit(featureIndex, Categorical), + splits(featureIndex)(0), Categorical, key) + } else { + new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), + Categorical, key) + } + } + } + } + featureIndex += 1 + } + + // Find all bins. + featureIndex = 0 + while (featureIndex < numFeatures) { + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { // Bins for categorical variables are already assigned. + bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), + splits(featureIndex)(0), Continuous, Double.MinValue) + for (index <- 1 until numBins - 1){ + val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), + Continuous, Double.MinValue) + bins(featureIndex)(index) = bin + } + bins(featureIndex)(numBins-1) = new Bin(splits(featureIndex)(numBins-2), + new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue) + } + featureIndex += 1 + } + (splits,bins) + case MinMax => + throw new UnsupportedOperationException("minmax not supported yet.") + case ApproxHist => + throw new UnsupportedOperationException("approximate histogram not supported yet.") + } + } + + val usage = """ + Usage: DecisionTreeRunner [slices] --algo --trainDataDir path --testDataDir path --maxDepth num [--impurity ] [--maxBins num] + """ + + def main(args: Array[String]) { + + if (args.length < 2) { + System.err.println(usage) + System.exit(1) + } + + val sc = new SparkContext(args(0), "DecisionTree") + + val argList = args.toList.drop(1) + type OptionMap = Map[Symbol, Any] + + def nextOption(map : OptionMap, list: List[String]): OptionMap = { + list match { + case Nil => map + case "--algo" :: string :: tail => nextOption(map ++ Map('algo -> string), tail) + case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail) + case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail) + case "--maxBins" :: string :: tail => nextOption(map ++ Map('maxBins -> string), tail) + case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string) + , tail) + case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string), + tail) + case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail) + case option :: tail => logError("Unknown option " + option) + sys.exit(1) + } + } + val options = nextOption(Map(), argList) + logDebug(options.toString()) + + // Load training data. + val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString) + + // Identify the type of algorithm. + val algoStr = options.get('algo).get.toString + val algo = algoStr match { + case "Classification" => Classification + case "Regression" => Regression + } + + // Identify the type of impurity. + val impurityStr = options.getOrElse('impurity, + if (algo == Classification) "Gini" else "Variance").toString + val impurity = impurityStr match { + case "Gini" => Gini + case "Entropy" => Entropy + case "Variance" => Variance + } + + val maxDepth = options.getOrElse('maxDepth, "1").toString.toInt + val maxBins = options.getOrElse('maxBins, "100").toString.toInt + + val strategy = new Strategy(algo, impurity, maxDepth, maxBins) + val model = DecisionTree.train(trainData, strategy) + + // Load test data. + val testData = loadLabeledData(sc, options.get('testDataDir).get.toString) + + // Measure algorithm accuracy + if (algo == Classification) { + val accuracy = accuracyScore(model, testData) + logDebug("accuracy = " + accuracy) + } + + if (algo == Regression) { + val mse = meanSquaredError(model, testData) + logDebug("mean square error = " + mse) + } + + sc.stop() + } + + /** + * Load labeled data from a file. The data format used here is + * , ..., + * where , are feature values in Double and is the corresponding label as Double. + * + * @param sc SparkContext + * @param dir Directory to the input data files. + * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is + * the label, and the second element represents the feature values (an array of Double). + */ + def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { + sc.textFile(dir).map { line => + val parts = line.trim().split(",") + val label = parts(0).toDouble + val features = Vectors.dense(parts.slice(1,parts.length).map(_.toDouble)) + LabeledPoint(label, features) + } + } + + // TODO: Port this method to a generic metrics package. + /** + * Calculates the classifier accuracy. + */ + private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint], + threshold: Double = 0.5): Double = { + def predictedValue(features: Vector) = { + if (model.predict(features) < threshold) 0.0 else 1.0 + } + val correctCount = data.filter(y => predictedValue(y.features) == y.label).count() + val count = data.count() + logDebug("correct prediction count = " + correctCount) + logDebug("data count = " + count) + correctCount.toDouble / count + } + + // TODO: Port this method to a generic metrics package + /** + * Calculates the mean squared error for regression. + */ + private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { + data.map { y => + val err = tree.predict(y.features) - y.label + err * err + }.mean() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md b/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md new file mode 100644 index 0000000000000..0fd71aa9735bc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md @@ -0,0 +1,17 @@ +This package contains the default implementation of the decision tree algorithm. + +The decision tree algorithm supports: ++ Binary classification ++ Regression ++ Information loss calculation with entropy and gini for classification and variance for regression ++ Both continuous and categorical features + +# Tree improvements ++ Node model pruning ++ Printing to dot files + +# Future Ensemble Extensions + ++ Random forests ++ Boosting ++ Extremely randomized trees diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala new file mode 100644 index 0000000000000..2dd1f0f27b8f5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +/** + * Enum to select the algorithm for the decision tree + */ +object Algo extends Enumeration { + type Algo = Value + val Classification, Regression = Value +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala new file mode 100644 index 0000000000000..09ee0586c58fa --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +/** + * Enum to describe whether a feature is "continuous" or "categorical" + */ +object FeatureType extends Enumeration { + type FeatureType = Value + val Continuous, Categorical = Value +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala new file mode 100644 index 0000000000000..2457a480c2a14 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +/** + * Enum for selecting the quantile calculation strategy + */ +object QuantileStrategy extends Enumeration { + type QuantileStrategy = Value + val Sort, MinMax, ApproxHist = Value +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala new file mode 100644 index 0000000000000..df565f3eb8859 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ + +/** + * Stores all the configuration options for tree construction + * @param algo classification or regression + * @param impurity criterion used for information gain calculation + * @param maxDepth maximum depth of the tree + * @param maxBins maximum number of bins used for splitting features + * @param quantileCalculationStrategy algorithm for calculating quantiles + * @param categoricalFeaturesInfo A map storing information about the categorical variables and the + * number of discrete values they take. For example, an entry (n -> + * k) implies the feature n is categorical with k categories 0, + * 1, 2, ... , k-1. It's important to note that features are + * zero-indexed. + */ +class Strategy ( + val algo: Algo, + val impurity: Impurity, + val maxDepth: Int, + val maxBins: Int = 100, + val quantileCalculationStrategy: QuantileStrategy = Sort, + val categoricalFeaturesInfo: Map[Int,Int] = Map[Int,Int]()) extends Serializable diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala new file mode 100644 index 0000000000000..b93995fcf9441 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +/** + * Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during + * binary classification. + */ +object Entropy extends Impurity { + + def log2(x: Double) = scala.math.log(x) / scala.math.log(2) + + /** + * entropy calculation + * @param c0 count of instances with label 0 + * @param c1 count of instances with label 1 + * @return entropy value + */ + def calculate(c0: Double, c1: Double): Double = { + if (c0 == 0 || c1 == 0) { + 0 + } else { + val total = c0 + c1 + val f0 = c0 / total + val f1 = c1 / total + -(f0 * log2(f0)) - (f1 * log2(f1)) + } + } + + def calculate(count: Double, sum: Double, sumSquares: Double): Double = + throw new UnsupportedOperationException("Entropy.calculate") +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala new file mode 100644 index 0000000000000..c0407554a91b3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +/** + * Class for calculating the + * [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]] + * during binary classification. + */ +object Gini extends Impurity { + + /** + * Gini coefficient calculation + * @param c0 count of instances with label 0 + * @param c1 count of instances with label 1 + * @return Gini coefficient value + */ + override def calculate(c0: Double, c1: Double): Double = { + if (c0 == 0 || c1 == 0) { + 0 + } else { + val total = c0 + c1 + val f0 = c0 / total + val f1 = c1 / total + 1 - f0 * f0 - f1 * f1 + } + } + + def calculate(count: Double, sum: Double, sumSquares: Double): Double = + throw new UnsupportedOperationException("Gini.calculate") +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala new file mode 100644 index 0000000000000..a4069063af2ad --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +/** + * Trait for calculating information gain. + */ +trait Impurity extends Serializable { + + /** + * information calculation for binary classification + * @param c0 count of instances with label 0 + * @param c1 count of instances with label 1 + * @return information value + */ + def calculate(c0 : Double, c1 : Double): Double + + /** + * information calculation for regression + * @param count number of instances + * @param sum sum of labels + * @param sumSquares summation of squares of the labels + * @return information value + */ + def calculate(count: Double, sum: Double, sumSquares: Double): Double + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala new file mode 100644 index 0000000000000..b74577dcec167 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +/** + * Class for calculating variance during regression + */ +object Variance extends Impurity { + override def calculate(c0: Double, c1: Double): Double = + throw new UnsupportedOperationException("Variance.calculate") + + /** + * variance calculation + * @param count number of instances + * @param sum sum of labels + * @param sumSquares summation of squares of the labels + */ + override def calculate(count: Double, sum: Double, sumSquares: Double): Double = { + val squaredLoss = sumSquares - (sum * sum) / count + squaredLoss / count + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala new file mode 100644 index 0000000000000..a57faa13745f7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import org.apache.spark.mllib.tree.configuration.FeatureType._ + +/** + * Used for "binning" the features bins for faster best split calculation. For a continuous + * feature, a bin is determined by a low and a high "split". For a categorical feature, + * the a bin is determined using a single label value (category). + * @param lowSplit signifying the lower threshold for the continuous feature to be + * accepted in the bin + * @param highSplit signifying the upper threshold for the continuous feature to be + * accepted in the bin + * @param featureType type of feature -- categorical or continuous + * @param category categorical label value accepted in the bin + */ +case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala new file mode 100644 index 0000000000000..a6dca84a2ce09 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.Vector + +/** + * Model to store the decision tree parameters + * @param topNode root node + * @param algo algorithm type -- classification or regression + */ +class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable { + + /** + * Predict values for a single data point using the model trained. + * + * @param features array representing a single data point + * @return Double prediction from the trained model + */ + def predict(features: Vector): Double = { + topNode.predictIfLeaf(features) + } + + /** + * Predict values for the given data set using the model trained. + * + * @param features RDD representing data points to be predicted + * @return RDD[Int] where each entry contains the corresponding prediction + */ + def predict(features: RDD[Vector]): RDD[Double] = { + features.map(x => predict(x)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala new file mode 100644 index 0000000000000..ebc9595eafef3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +/** + * Filter specifying a split and type of comparison to be applied on features + * @param split split specifying the feature index, type and threshold + * @param comparison integer specifying <,=,> + */ +case class Filter(split: Split, comparison: Int) { + // Comparison -1,0,1 signifies <.=,> + override def toString = " split = " + split + "comparison = " + comparison +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala new file mode 100644 index 0000000000000..99bf79cf12e45 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +/** + * Information gain statistics for each split + * @param gain information gain value + * @param impurity current node impurity + * @param leftImpurity left node impurity + * @param rightImpurity right node impurity + * @param predict predicted value + */ +class InformationGainStats( + val gain: Double, + val impurity: Double, + val leftImpurity: Double, + val rightImpurity: Double, + val predict: Double) extends Serializable { + + override def toString = { + "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f" + .format(gain, impurity, leftImpurity, rightImpurity, predict) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala new file mode 100644 index 0000000000000..aac3f9ce308f7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import org.apache.spark.Logging +import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.mllib.linalg.Vector + +/** + * Node in a decision tree + * @param id integer node id + * @param predict predicted value at the node + * @param isLeaf whether the leaf is a node + * @param split split to calculate left and right nodes + * @param leftNode left child + * @param rightNode right child + * @param stats information gain stats + */ +class Node ( + val id: Int, + val predict: Double, + val isLeaf: Boolean, + val split: Option[Split], + var leftNode: Option[Node], + var rightNode: Option[Node], + val stats: Option[InformationGainStats]) extends Serializable with Logging { + + override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + + "split = " + split + ", stats = " + stats + + /** + * build the left node and right nodes if not leaf + * @param nodes array of nodes + */ + def build(nodes: Array[Node]): Unit = { + + logDebug("building node " + id + " at level " + + (scala.math.log(id + 1)/scala.math.log(2)).toInt ) + logDebug("id = " + id + ", split = " + split) + logDebug("stats = " + stats) + logDebug("predict = " + predict) + if (!isLeaf) { + val leftNodeIndex = id * 2 + 1 + val rightNodeIndex = id * 2 + 2 + leftNode = Some(nodes(leftNodeIndex)) + rightNode = Some(nodes(rightNodeIndex)) + leftNode.get.build(nodes) + rightNode.get.build(nodes) + } + } + + /** + * predict value if node is not leaf + * @param feature feature value + * @return predicted value + */ + def predictIfLeaf(feature: Vector) : Double = { + if (isLeaf) { + predict + } else{ + if (split.get.featureType == Continuous) { + if (feature(split.get.feature) <= split.get.threshold) { + leftNode.get.predictIfLeaf(feature) + } else { + rightNode.get.predictIfLeaf(feature) + } + } else { + if (split.get.categories.contains(feature(split.get.feature))) { + leftNode.get.predictIfLeaf(feature) + } else { + rightNode.get.predictIfLeaf(feature) + } + } + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala new file mode 100644 index 0000000000000..4e64a81dda74e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType + +/** + * Split applied to a feature + * @param feature feature index + * @param threshold threshold for continuous feature + * @param featureType type of feature -- categorical or continuous + * @param categories accepted values for categorical variables + */ +case class Split( + feature: Int, + threshold: Double, + featureType: FeatureType, + categories: List[Double]){ + + override def toString = + "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType + + ", categories = " + categories +} + +/** + * Split with minimum threshold for continuous features. Helps with the smallest bin creation. + * @param feature feature index + * @param featureType type of feature -- categorical or continuous + */ +class DummyLowSplit(feature: Int, featureType: FeatureType) + extends Split(feature, Double.MinValue, featureType, List()) + +/** + * Split with maximum threshold for continuous features. Helps with the highest bin creation. + * @param feature feature index + * @param featureType type of feature -- categorical or continuous + */ +class DummyHighSplit(feature: Int, featureType: FeatureType) + extends Split(feature, Double.MaxValue, featureType, List()) + +/** + * Split with no acceptable feature values for categorical features. Helps with the first bin + * creation. + * @param feature feature index + * @param featureType type of feature -- categorical or continuous + */ +class DummyCategoricalSplit(feature: Int, featureType: FeatureType) + extends Split(feature, Double.MaxValue, featureType, List()) + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index 2e03684e62861..81e4eda2a68c4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -24,6 +24,7 @@ import org.jblas.DoubleMatrix import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint /** @@ -74,7 +75,7 @@ object LinearDataGenerator { val y = x.map { xi => new DoubleMatrix(1, xi.length, xi: _*).dot(weightsMat) + intercept + eps * rnd.nextGaussian() } - y.zip(x).map(p => LabeledPoint(p._1, p._2)) + y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2))) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala index 52c4a71d621a1..61498dcc2be00 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala @@ -22,6 +22,7 @@ import scala.util.Random import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.linalg.Vectors /** * Generate test data for LogisticRegression. This class chooses positive labels @@ -54,7 +55,7 @@ object LogisticRegressionDataGenerator { val x = Array.fill[Double](nfeatures) { rnd.nextGaussian() + (y * eps) } - LabeledPoint(y, x) + LabeledPoint(y, Vectors.dense(x)) } data } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 08cd9ab05547b..cb85e433bfc73 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -17,15 +17,13 @@ package org.apache.spark.mllib.util +import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV, + squaredDistance => breezeSquaredDistance} + import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ - -import org.jblas.DoubleMatrix - import org.apache.spark.mllib.regression.LabeledPoint - -import breeze.linalg.{Vector => BV, SparseVector => BSV, squaredDistance => breezeSquaredDistance} +import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * Helper methods to load, save and pre-process data used in ML Lib. @@ -40,6 +38,107 @@ object MLUtils { eps } + /** + * Multiclass label parser, which parses a string into double. + */ + val multiclassLabelParser: String => Double = _.toDouble + + /** + * Binary label parser, which outputs 1.0 (positive) if the value is greater than 0.5, + * or 0.0 (negative) otherwise. + */ + val binaryLabelParser: String => Double = label => if (label.toDouble > 0.5) 1.0 else 0.0 + + /** + * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint]. + * The LIBSVM format is a text-based format used by LIBSVM and LIBLINEAR. + * Each line represents a labeled sparse feature vector using the following format: + * {{{label index1:value1 index2:value2 ...}}} + * where the indices are one-based and in ascending order. + * This method parses each line into a [[org.apache.spark.mllib.regression.LabeledPoint]], + * where the feature indices are converted to zero-based. + * + * @param sc Spark context + * @param path file or directory path in any Hadoop-supported file system URI + * @param labelParser parser for labels, default: 1.0 if label > 0.5 or 0.0 otherwise + * @param numFeatures number of features, which will be determined from the input data if a + * negative value is given. The default value is -1. + * @param minSplits min number of partitions, default: sc.defaultMinSplits + * @return labeled data stored as an RDD[LabeledPoint] + */ + def loadLibSVMData( + sc: SparkContext, + path: String, + labelParser: String => Double, + numFeatures: Int, + minSplits: Int): RDD[LabeledPoint] = { + val parsed = sc.textFile(path, minSplits) + .map(_.trim) + .filter(!_.isEmpty) + .map(_.split(' ')) + // Determine number of features. + val d = if (numFeatures >= 0) { + numFeatures + } else { + parsed.map { items => + if (items.length > 1) { + items.last.split(':')(0).toInt + } else { + 0 + } + }.reduce(math.max) + } + parsed.map { items => + val label = labelParser(items.head) + val (indices, values) = items.tail.map { item => + val indexAndValue = item.split(':') + val index = indexAndValue(0).toInt - 1 + val value = indexAndValue(1).toDouble + (index, value) + }.unzip + LabeledPoint(label, Vectors.sparse(d, indices.toArray, values.toArray)) + } + } + + // Convenient methods for calling from Java. + + /** + * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint], + * with number of features determined automatically and the default number of partitions. + */ + def loadLibSVMData(sc: SparkContext, path: String): RDD[LabeledPoint] = + loadLibSVMData(sc, path, binaryLabelParser, -1, sc.defaultMinSplits) + + /** + * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint], + * with number of features specified explicitly and the default number of partitions. + */ + def loadLibSVMData(sc: SparkContext, path: String, numFeatures: Int): RDD[LabeledPoint] = + loadLibSVMData(sc, path, binaryLabelParser, numFeatures, sc.defaultMinSplits) + + /** + * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], + * with the given label parser, number of features determined automatically, + * and the default number of partitions. + */ + def loadLibSVMData( + sc: SparkContext, + path: String, + labelParser: String => Double): RDD[LabeledPoint] = + loadLibSVMData(sc, path, labelParser, -1, sc.defaultMinSplits) + + /** + * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], + * with the given label parser, number of features specified explicitly, + * and the default number of partitions. + */ + def loadLibSVMData( + sc: SparkContext, + path: String, + labelParser: String => Double, + numFeatures: Int): RDD[LabeledPoint] = + loadLibSVMData(sc, path, labelParser, numFeatures, sc.defaultMinSplits) + /** * Load labeled data from a file. The data format used here is * , ... @@ -54,7 +153,7 @@ object MLUtils { sc.textFile(dir).map { line => val parts = line.split(',') val label = parts(0).toDouble - val features = parts(1).trim().split(' ').map(_.toDouble) + val features = Vectors.dense(parts(1).trim().split(' ').map(_.toDouble)) LabeledPoint(label, features) } } @@ -68,7 +167,7 @@ object MLUtils { * @param dir Directory to save the data. */ def saveLabeledData(data: RDD[LabeledPoint], dir: String) { - val dataStr = data.map(x => x.label + "," + x.features.mkString(" ")) + val dataStr = data.map(x => x.label + "," + x.features.toArray.mkString(" ")) dataStr.saveAsTextFile(dir) } @@ -76,44 +175,52 @@ object MLUtils { * Utility function to compute mean and standard deviation on a given dataset. * * @param data - input data set whose statistics are computed - * @param nfeatures - number of features - * @param nexamples - number of examples in input dataset + * @param numFeatures - number of features + * @param numExamples - number of examples in input dataset * * @return (yMean, xColMean, xColSd) - Tuple consisting of * yMean - mean of the labels * xColMean - Row vector with mean for every column (or feature) of the input data * xColSd - Row vector standard deviation for every column (or feature) of the input data. */ - def computeStats(data: RDD[LabeledPoint], nfeatures: Int, nexamples: Long): - (Double, DoubleMatrix, DoubleMatrix) = { - val yMean: Double = data.map { labeledPoint => labeledPoint.label }.reduce(_ + _) / nexamples - - // NOTE: We shuffle X by column here to compute column sum and sum of squares. - val xColSumSq: RDD[(Int, (Double, Double))] = data.flatMap { labeledPoint => - val nCols = labeledPoint.features.length - // Traverse over every column and emit (col, value, value^2) - Iterator.tabulate(nCols) { i => - (i, (labeledPoint.features(i), labeledPoint.features(i)*labeledPoint.features(i))) - } - }.reduceByKey { case(x1, x2) => - (x1._1 + x2._1, x1._2 + x2._2) + def computeStats( + data: RDD[LabeledPoint], + numFeatures: Int, + numExamples: Long): (Double, Vector, Vector) = { + val brzData = data.map { case LabeledPoint(label, features) => + (label, features.toBreeze) } - val xColSumsMap = xColSumSq.collectAsMap() - - val xColMean = DoubleMatrix.zeros(nfeatures, 1) - val xColSd = DoubleMatrix.zeros(nfeatures, 1) - - // Compute mean and unbiased variance using column sums - var col = 0 - while (col < nfeatures) { - xColMean.put(col, xColSumsMap(col)._1 / nexamples) - val variance = - (xColSumsMap(col)._2 - (math.pow(xColSumsMap(col)._1, 2) / nexamples)) / nexamples - xColSd.put(col, math.sqrt(variance)) - col += 1 + val aggStats = brzData.aggregate( + (0L, 0.0, BDV.zeros[Double](numFeatures), BDV.zeros[Double](numFeatures)) + )( + seqOp = (c, v) => (c, v) match { + case ((n, sumLabel, sum, sumSq), (label, features)) => + features.activeIterator.foreach { case (i, x) => + sumSq(i) += x * x + } + (n + 1L, sumLabel + label, sum += features, sumSq) + }, + combOp = (c1, c2) => (c1, c2) match { + case ((n1, sumLabel1, sum1, sumSq1), (n2, sumLabel2, sum2, sumSq2)) => + (n1 + n2, sumLabel1 + sumLabel2, sum1 += sum2, sumSq1 += sumSq2) + } + ) + val (nl, sumLabel, sum, sumSq) = aggStats + + require(nl > 0, "Input data is empty.") + require(nl == numExamples) + + val n = nl.toDouble + val yMean = sumLabel / n + val mean = sum / n + val std = new Array[Double](sum.length) + var i = 0 + while (i < numFeatures) { + std(i) = sumSq(i) / n - mean(i) * mean(i) + i += 1 } - (yMean, xColMean, xColSd) + (yMean, Vectors.fromBreeze(mean), Vectors.dense(std)) } /** @@ -144,6 +251,18 @@ object MLUtils { val sumSquaredNorm = norm1 * norm1 + norm2 * norm2 val normDiff = norm1 - norm2 var sqDist = 0.0 + /* + * The relative error is + *
+     * EPSILON * ( \|a\|_2^2 + \|b\\_2^2 + 2 |a^T b|) / ( \|a - b\|_2^2 ),
+     * 
+ * which is bounded by + *
+     * 2.0 * EPSILON * ( \|a\|_2^2 + \|b\|_2^2 ) / ( (\|a\|_2 - \|b\|_2)^2 ).
+     * 
+ * The bound doesn't need the inner product, so we can use it as a sufficient condition to + * check quickly whether the inner product approach is accurate. + */ val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON) if (precisionBound1 < precision) { sqDist = sumSquaredNorm - 2.0 * v1.dot(v2) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala index c96c94f70eef7..e300c3dbe1fe0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala @@ -23,6 +23,7 @@ import org.jblas.DoubleMatrix import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint /** @@ -58,7 +59,7 @@ object SVMDataGenerator { } val yD = new DoubleMatrix(1, x.length, x: _*).dot(trueWeights) + rnd.nextGaussian() * 0.1 val y = if (yD < 0) 0.0 else 1.0 - LabeledPoint(y, x) + LabeledPoint(y, Vectors.dense(x)) } MLUtils.saveLabeledData(data, outputPath) diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java index 073ded6f36933..c80b1134ed1b2 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java @@ -19,6 +19,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; import org.junit.After; import org.junit.Assert; @@ -45,12 +46,12 @@ public void tearDown() { } private static final List POINTS = Arrays.asList( - new LabeledPoint(0, new double[] {1.0, 0.0, 0.0}), - new LabeledPoint(0, new double[] {2.0, 0.0, 0.0}), - new LabeledPoint(1, new double[] {0.0, 1.0, 0.0}), - new LabeledPoint(1, new double[] {0.0, 2.0, 0.0}), - new LabeledPoint(2, new double[] {0.0, 0.0, 1.0}), - new LabeledPoint(2, new double[] {0.0, 0.0, 2.0}) + new LabeledPoint(0, Vectors.dense(1.0, 0.0, 0.0)), + new LabeledPoint(0, Vectors.dense(2.0, 0.0, 0.0)), + new LabeledPoint(1, Vectors.dense(0.0, 1.0, 0.0)), + new LabeledPoint(1, Vectors.dense(0.0, 2.0, 0.0)), + new LabeledPoint(2, Vectors.dense(0.0, 0.0, 1.0)), + new LabeledPoint(2, Vectors.dense(0.0, 0.0, 2.0)) ); private int validatePrediction(List points, NaiveBayesModel model) { diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java index 117e5eaa8b78e..4701a5e545020 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.mllib.classification; - import java.io.Serializable; import java.util.List; @@ -28,7 +27,6 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; - import org.apache.spark.mllib.regression.LabeledPoint; public class JavaSVMSuite implements Serializable { @@ -94,5 +92,4 @@ public void runSVMUsingStaticMethods() { int numAccurate = validatePrediction(validationData, model); Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); } - } diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java index 2c4d795f96e4e..c6d8425ffc38d 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java @@ -19,10 +19,10 @@ import java.io.Serializable; -import com.google.common.collect.Lists; - import scala.Tuple2; +import com.google.common.collect.Lists; + import org.junit.Test; import static org.junit.Assert.*; @@ -36,7 +36,7 @@ public void denseArrayConstruction() { @Test public void sparseArrayConstruction() { - Vector v = Vectors.sparse(3, Lists.newArrayList( + Vector v = Vectors.sparse(3, Lists.>newArrayList( new Tuple2(0, 2.0), new Tuple2(2, 3.0))); assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0); diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java index f44b25cd44d19..f725924a2d971 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java @@ -59,7 +59,7 @@ int validatePrediction(List validationData, LassoModel model) { @Test public void runLassoUsingConstructor() { int nPoints = 10000; - double A = 2.0; + double A = 0.0; double[] weights = {-1.5, 1.0e-2}; JavaRDD testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, @@ -80,7 +80,7 @@ public void runLassoUsingConstructor() { @Test public void runLassoUsingStaticMethods() { int nPoints = 10000; - double A = 2.0; + double A = 0.0; double[] weights = {-1.5, 1.0e-2}; JavaRDD testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java index 2fdd5fc8fdca6..03714ae7e4d00 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java @@ -55,30 +55,27 @@ public void tearDown() { return errorSum / validationData.size(); } - List generateRidgeData(int numPoints, int nfeatures, double eps) { + List generateRidgeData(int numPoints, int numFeatures, double std) { org.jblas.util.Random.seed(42); // Pick weights as random values distributed uniformly in [-0.5, 0.5] - DoubleMatrix w = DoubleMatrix.rand(nfeatures, 1).subi(0.5); - // Set first two weights to eps - w.put(0, 0, eps); - w.put(1, 0, eps); - return LinearDataGenerator.generateLinearInputAsList(0.0, w.data, numPoints, 42, eps); + DoubleMatrix w = DoubleMatrix.rand(numFeatures, 1).subi(0.5); + return LinearDataGenerator.generateLinearInputAsList(0.0, w.data, numPoints, 42, std); } @Test public void runRidgeRegressionUsingConstructor() { - int nexamples = 200; - int nfeatures = 20; - double eps = 10.0; - List data = generateRidgeData(2*nexamples, nfeatures, eps); + int numExamples = 50; + int numFeatures = 20; + List data = generateRidgeData(2*numExamples, numFeatures, 10.0); - JavaRDD testRDD = sc.parallelize(data.subList(0, nexamples)); - List validationData = data.subList(nexamples, 2*nexamples); + JavaRDD testRDD = sc.parallelize(data.subList(0, numExamples)); + List validationData = data.subList(numExamples, 2 * numExamples); RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD(); - ridgeSGDImpl.optimizer().setStepSize(1.0) - .setRegParam(0.0) - .setNumIterations(200); + ridgeSGDImpl.optimizer() + .setStepSize(1.0) + .setRegParam(0.0) + .setNumIterations(200); RidgeRegressionModel model = ridgeSGDImpl.run(testRDD.rdd()); double unRegularizedErr = predictionError(validationData, model); @@ -91,13 +88,12 @@ public void runRidgeRegressionUsingConstructor() { @Test public void runRidgeRegressionUsingStaticMethods() { - int nexamples = 200; - int nfeatures = 20; - double eps = 10.0; - List data = generateRidgeData(2*nexamples, nfeatures, eps); + int numExamples = 50; + int numFeatures = 20; + List data = generateRidgeData(2 * numExamples, numFeatures, 10.0); - JavaRDD testRDD = sc.parallelize(data.subList(0, nexamples)); - List validationData = data.subList(nexamples, 2*nexamples); + JavaRDD testRDD = sc.parallelize(data.subList(0, numExamples)); + List validationData = data.subList(numExamples, 2 * numExamples); RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0); double unRegularizedErr = predictionError(validationData, model); diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 05322b024d5f6..1e03c9df820b0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -20,11 +20,10 @@ package org.apache.spark.mllib.classification import scala.util.Random import scala.collection.JavaConversions._ -import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite import org.scalatest.matchers.ShouldMatchers -import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.LocalSparkContext @@ -61,7 +60,7 @@ object LogisticRegressionSuite { if (yVal > 0) 1 else 0 } - val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Array(x1(i)))) + val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Vectors.dense(Array(x1(i))))) testData } @@ -113,7 +112,7 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Shoul val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) val initialB = -1.0 - val initialWeights = Array(initialB) + val initialWeights = Vectors.dense(initialB) val testRDD = sc.parallelize(testData, 2) testRDD.cache() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 9dd6c79ee6ad8..516895d04222d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.mllib.classification import scala.util.Random -import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.LocalSparkContext @@ -54,7 +54,7 @@ object NaiveBayesSuite { if (rnd.nextDouble() < _theta(y)(j)) 1 else 0 } - LabeledPoint(y, xi) + LabeledPoint(y, Vectors.dense(xi)) } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index bc7abb568a172..dfacbfeee6fb4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.mllib.classification import scala.util.Random import scala.collection.JavaConversions._ -import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite import org.jblas.DoubleMatrix @@ -28,6 +27,7 @@ import org.jblas.DoubleMatrix import org.apache.spark.SparkException import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.linalg.Vectors object SVMSuite { @@ -54,7 +54,7 @@ object SVMSuite { intercept + 0.01 * rnd.nextGaussian() if (yD < 0) 0.0 else 1.0 } - y.zip(x).map(p => LabeledPoint(p._1, p._2)) + y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2))) } } @@ -110,7 +110,7 @@ class SVMSuite extends FunSuite with LocalSparkContext { val initialB = -1.0 val initialC = -1.0 - val initialWeights = Array(initialB,initialC) + val initialWeights = Vectors.dense(initialB, initialC) val testRDD = sc.parallelize(testData, 2) testRDD.cache() @@ -150,10 +150,10 @@ class SVMSuite extends FunSuite with LocalSparkContext { } intercept[SparkException] { - val model = SVMWithSGD.train(testRDDInvalid, 100) + SVMWithSGD.train(testRDDInvalid, 100) } // Turning off data validation should not throw an exception - val noValidationModel = new SVMWithSGD().setValidateData(false).run(testRDDInvalid) + new SVMWithSGD().setValidateData(false).run(testRDDInvalid) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index 631d0e2ad9cdb..c4b433499a091 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -20,13 +20,12 @@ package org.apache.spark.mllib.optimization import scala.util.Random import scala.collection.JavaConversions._ -import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite import org.scalatest.matchers.ShouldMatchers -import org.apache.spark.SparkContext import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.linalg.Vectors object GradientDescentSuite { @@ -58,8 +57,7 @@ object GradientDescentSuite { if (yVal > 0) 1 else 0 } - val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Array(x1(i)))) - testData + (0 until nPoints).map(i => LabeledPoint(y(i), Vectors.dense(x1(i)))) } } @@ -83,11 +81,11 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMa // Add a extra variable consisting of all 1.0's for the intercept. val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42) val data = testData.map { case LabeledPoint(label, features) => - label -> Array(1.0, features: _*) + label -> Vectors.dense(1.0, features.toArray: _*) } val dataRDD = sc.parallelize(data, 2).cache() - val initialWeightsWithIntercept = Array(1.0, initialWeights: _*) + val initialWeightsWithIntercept = Vectors.dense(1.0, initialWeights: _*) val (_, loss) = GradientDescent.runMiniBatchSGD( dataRDD, @@ -113,13 +111,13 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMa // Add a extra variable consisting of all 1.0's for the intercept. val testData = GradientDescentSuite.generateGDInput(2.0, -1.5, 10000, 42) val data = testData.map { case LabeledPoint(label, features) => - label -> Array(1.0, features: _*) + label -> Vectors.dense(1.0, features.toArray: _*) } val dataRDD = sc.parallelize(data, 2).cache() // Prepare non-zero weights - val initialWeightsWithIntercept = Array(1.0, 0.5) + val initialWeightsWithIntercept = Vectors.dense(1.0, 0.5) val regParam0 = 0 val (newWeights0, loss0) = GradientDescent.runMiniBatchSGD( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index 2cebac943e15f..6aad9eb84e13c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.regression import org.scalatest.FunSuite +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} class LassoSuite extends FunSuite with LocalSparkContext { @@ -33,29 +34,33 @@ class LassoSuite extends FunSuite with LocalSparkContext { } test("Lasso local random SGD") { - val nPoints = 10000 + val nPoints = 1000 val A = 2.0 val B = -1.5 val C = 1.0e-2 - val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 42) - - val testRDD = sc.parallelize(testData, 2) - testRDD.cache() + val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B, C), nPoints, 42) + .map { case LabeledPoint(label, features) => + LabeledPoint(label, Vectors.dense(1.0 +: features.toArray)) + } + val testRDD = sc.parallelize(testData, 2).cache() val ls = new LassoWithSGD() - ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(20) + ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40) val model = ls.run(testRDD) - val weight0 = model.weights(0) val weight1 = model.weights(1) - assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") - assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") - assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]") + val weight2 = model.weights(2) + assert(weight0 >= 1.9 && weight0 <= 2.1, weight0 + " not in [1.9, 2.1]") + assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]") + assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]") val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17) + .map { case LabeledPoint(label, features) => + LabeledPoint(label, Vectors.dense(1.0 +: features.toArray)) + } val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. @@ -66,33 +71,39 @@ class LassoSuite extends FunSuite with LocalSparkContext { } test("Lasso local random SGD with initial weights") { - val nPoints = 10000 + val nPoints = 1000 val A = 2.0 val B = -1.5 val C = 1.0e-2 - val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 42) + val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B, C), nPoints, 42) + .map { case LabeledPoint(label, features) => + LabeledPoint(label, Vectors.dense(1.0 +: features.toArray)) + } + val initialA = -1.0 val initialB = -1.0 val initialC = -1.0 - val initialWeights = Array(initialB,initialC) + val initialWeights = Vectors.dense(initialA, initialB, initialC) - val testRDD = sc.parallelize(testData, 2) - testRDD.cache() + val testRDD = sc.parallelize(testData, 2).cache() val ls = new LassoWithSGD() - ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(20) + ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40) val model = ls.run(testRDD, initialWeights) - val weight0 = model.weights(0) val weight1 = model.weights(1) - assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") - assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") - assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]") + val weight2 = model.weights(2) + assert(weight0 >= 1.9 && weight0 <= 2.1, weight0 + " not in [1.9, 2.1]") + assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]") + assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]") val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17) + .map { case LabeledPoint(label, features) => + LabeledPoint(label, Vectors.dense(1.0 +: features.toArray)) + } val validationRDD = sc.parallelize(validationData,2) // Test prediction on RDD. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index 5d251bcbf35db..2f7d30708ce17 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.regression import org.scalatest.FunSuite +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} class LinearRegressionSuite extends FunSuite with LocalSparkContext { @@ -40,11 +41,12 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext { linReg.optimizer.setNumIterations(1000).setStepSize(1.0) val model = linReg.run(testRDD) - assert(model.intercept >= 2.5 && model.intercept <= 3.5) - assert(model.weights.length === 2) - assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0) - assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0) + + val weights = model.weights + assert(weights.size === 2) + assert(weights(0) >= 9.0 && weights(0) <= 11.0) + assert(weights(1) >= 9.0 && weights(1) <= 11.0) val validationData = LinearDataGenerator.generateLinearInput( 3.0, Array(10.0, 10.0), 100, 17) @@ -67,9 +69,11 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext { val model = linReg.run(testRDD) assert(model.intercept === 0.0) - assert(model.weights.length === 2) - assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0) - assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0) + + val weights = model.weights + assert(weights.size === 2) + assert(weights(0) >= 9.0 && weights(0) <= 11.0) + assert(weights(1) >= 9.0 && weights(1) <= 11.0) val validationData = LinearDataGenerator.generateLinearInput( 0.0, Array(10.0, 10.0), 100, 17) @@ -81,4 +85,40 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext { // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + + // Test if we can correctly learn Y = 10*X1 + 10*X10000 + test("sparse linear regression without intercept") { + val denseRDD = sc.parallelize( + LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42), 2) + val sparseRDD = denseRDD.map { case LabeledPoint(label, v) => + val sv = Vectors.sparse(10000, Seq((0, v(0)), (9999, v(1)))) + LabeledPoint(label, sv) + }.cache() + val linReg = new LinearRegressionWithSGD().setIntercept(false) + linReg.optimizer.setNumIterations(1000).setStepSize(1.0) + + val model = linReg.run(sparseRDD) + + assert(model.intercept === 0.0) + + val weights = model.weights + assert(weights.size === 10000) + assert(weights(0) >= 9.0 && weights(0) <= 11.0) + assert(weights(9999) >= 9.0 && weights(9999) <= 11.0) + + val validationData = LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 17) + val sparseValidationData = validationData.map { case LabeledPoint(label, v) => + val sv = Vectors.sparse(10000, Seq((0, v(0)), (9999, v(1)))) + LabeledPoint(label, sv) + } + val sparseValidationRDD = sc.parallelize(sparseValidationData, 2) + + // Test prediction on RDD. + validatePrediction( + model.predict(sparseValidationRDD.map(_.features)).collect(), sparseValidationData) + + // Test prediction on Array. + validatePrediction( + sparseValidationData.map(row => model.predict(row.features)), sparseValidationData) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index b2044ed0d8066..f66fc6ea6c1ec 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.mllib.regression -import org.jblas.DoubleMatrix import org.scalatest.FunSuite +import org.jblas.DoubleMatrix + import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} class RidgeRegressionSuite extends FunSuite with LocalSparkContext { @@ -30,22 +31,22 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext { }.reduceLeft(_ + _) / predictions.size } - test("regularization with skewed weights") { - val nexamples = 200 - val nfeatures = 20 - val eps = 10 + test("ridge regression can help avoid overfitting") { + + // For small number of examples and large variance of error distribution, + // ridge regression should give smaller generalization error that linear regression. + + val numExamples = 50 + val numFeatures = 20 org.jblas.util.Random.seed(42) // Pick weights as random values distributed uniformly in [-0.5, 0.5] - val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5) - // Set first two weights to eps - w.put(0, 0, eps) - w.put(1, 0, eps) + val w = DoubleMatrix.rand(numFeatures, 1).subi(0.5) // Use half of data for training and other half for validation - val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, 2*nexamples, 42, eps) - val testData = data.take(nexamples) - val validationData = data.takeRight(nexamples) + val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, 2 * numExamples, 42, 10.0) + val testData = data.take(numExamples) + val validationData = data.takeRight(numExamples) val testRDD = sc.parallelize(testData, 2).cache() val validationRDD = sc.parallelize(validationData, 2).cache() @@ -67,7 +68,7 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext { val ridgeErr = predictionError( ridgeModel.predict(validationRDD.map(_.features)).collect(), validationData) - // Ridge CV-error should be lower than linear regression + // Ridge validation error should be lower than linear regression. assert(ridgeErr < linearErr, "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")") } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala new file mode 100644 index 0000000000000..350130c914f26 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -0,0 +1,426 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} +import org.apache.spark.mllib.tree.model.Filter +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.mllib.linalg.Vectors + +class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { + + @transient private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + } + + override def afterAll() { + sc.stop() + System.clearProperty("spark.driver.port") + } + + test("split and bin calculation") { + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Gini, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(bins.length === 2) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + } + + test("split and bin calculation for categorical variables") { + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(bins.length === 2) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + // Check splits. + + assert(splits(0)(0).feature === 0) + assert(splits(0)(0).threshold === Double.MinValue) + assert(splits(0)(0).featureType === Categorical) + assert(splits(0)(0).categories.length === 1) + assert(splits(0)(0).categories.contains(1.0)) + + assert(splits(0)(1).feature === 0) + assert(splits(0)(1).threshold === Double.MinValue) + assert(splits(0)(1).featureType === Categorical) + assert(splits(0)(1).categories.length === 2) + assert(splits(0)(1).categories.contains(1.0)) + assert(splits(0)(1).categories.contains(0.0)) + + assert(splits(0)(2) === null) + + assert(splits(1)(0).feature === 1) + assert(splits(1)(0).threshold === Double.MinValue) + assert(splits(1)(0).featureType === Categorical) + assert(splits(1)(0).categories.length === 1) + assert(splits(1)(0).categories.contains(0.0)) + + assert(splits(1)(1).feature === 1) + assert(splits(1)(1).threshold === Double.MinValue) + assert(splits(1)(1).featureType === Categorical) + assert(splits(1)(1).categories.length === 2) + assert(splits(1)(1).categories.contains(1.0)) + assert(splits(1)(1).categories.contains(0.0)) + + assert(splits(1)(2) === null) + + // Check bins. + + assert(bins(0)(0).category === 1.0) + assert(bins(0)(0).lowSplit.categories.length === 0) + assert(bins(0)(0).highSplit.categories.length === 1) + assert(bins(0)(0).highSplit.categories.contains(1.0)) + + assert(bins(0)(1).category === 0.0) + assert(bins(0)(1).lowSplit.categories.length === 1) + assert(bins(0)(1).lowSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.length === 2) + assert(bins(0)(1).highSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.contains(0.0)) + + assert(bins(0)(2) === null) + + assert(bins(1)(0).category === 0.0) + assert(bins(1)(0).lowSplit.categories.length === 0) + assert(bins(1)(0).highSplit.categories.length === 1) + assert(bins(1)(0).highSplit.categories.contains(0.0)) + + assert(bins(1)(1).category === 1.0) + assert(bins(1)(1).lowSplit.categories.length === 1) + assert(bins(1)(1).lowSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.length === 2) + assert(bins(1)(1).highSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.contains(1.0)) + + assert(bins(1)(2) === null) + } + + test("split and bin calculations for categorical variables with no sample for one category") { + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + + // Check splits. + + assert(splits(0)(0).feature === 0) + assert(splits(0)(0).threshold === Double.MinValue) + assert(splits(0)(0).featureType === Categorical) + assert(splits(0)(0).categories.length === 1) + assert(splits(0)(0).categories.contains(1.0)) + + assert(splits(0)(1).feature === 0) + assert(splits(0)(1).threshold === Double.MinValue) + assert(splits(0)(1).featureType === Categorical) + assert(splits(0)(1).categories.length === 2) + assert(splits(0)(1).categories.contains(1.0)) + assert(splits(0)(1).categories.contains(0.0)) + + assert(splits(0)(2).feature === 0) + assert(splits(0)(2).threshold === Double.MinValue) + assert(splits(0)(2).featureType === Categorical) + assert(splits(0)(2).categories.length === 3) + assert(splits(0)(2).categories.contains(1.0)) + assert(splits(0)(2).categories.contains(0.0)) + assert(splits(0)(2).categories.contains(2.0)) + + assert(splits(0)(3) === null) + + assert(splits(1)(0).feature === 1) + assert(splits(1)(0).threshold === Double.MinValue) + assert(splits(1)(0).featureType === Categorical) + assert(splits(1)(0).categories.length === 1) + assert(splits(1)(0).categories.contains(0.0)) + + assert(splits(1)(1).feature === 1) + assert(splits(1)(1).threshold === Double.MinValue) + assert(splits(1)(1).featureType === Categorical) + assert(splits(1)(1).categories.length === 2) + assert(splits(1)(1).categories.contains(1.0)) + assert(splits(1)(1).categories.contains(0.0)) + + assert(splits(1)(2).feature === 1) + assert(splits(1)(2).threshold === Double.MinValue) + assert(splits(1)(2).featureType === Categorical) + assert(splits(1)(2).categories.length === 3) + assert(splits(1)(2).categories.contains(1.0)) + assert(splits(1)(2).categories.contains(0.0)) + assert(splits(1)(2).categories.contains(2.0)) + + assert(splits(1)(3) === null) + + // Check bins. + + assert(bins(0)(0).category === 1.0) + assert(bins(0)(0).lowSplit.categories.length === 0) + assert(bins(0)(0).highSplit.categories.length === 1) + assert(bins(0)(0).highSplit.categories.contains(1.0)) + + assert(bins(0)(1).category === 0.0) + assert(bins(0)(1).lowSplit.categories.length === 1) + assert(bins(0)(1).lowSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.length === 2) + assert(bins(0)(1).highSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.contains(0.0)) + + assert(bins(0)(2).category === 2.0) + assert(bins(0)(2).lowSplit.categories.length === 2) + assert(bins(0)(2).lowSplit.categories.contains(1.0)) + assert(bins(0)(2).lowSplit.categories.contains(0.0)) + assert(bins(0)(2).highSplit.categories.length === 3) + assert(bins(0)(2).highSplit.categories.contains(1.0)) + assert(bins(0)(2).highSplit.categories.contains(0.0)) + assert(bins(0)(2).highSplit.categories.contains(2.0)) + + assert(bins(0)(3) === null) + + assert(bins(1)(0).category === 0.0) + assert(bins(1)(0).lowSplit.categories.length === 0) + assert(bins(1)(0).highSplit.categories.length === 1) + assert(bins(1)(0).highSplit.categories.contains(0.0)) + + assert(bins(1)(1).category === 1.0) + assert(bins(1)(1).lowSplit.categories.length === 1) + assert(bins(1)(1).lowSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.length === 2) + assert(bins(1)(1).highSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.contains(1.0)) + + assert(bins(1)(2).category === 2.0) + assert(bins(1)(2).lowSplit.categories.length === 2) + assert(bins(1)(2).lowSplit.categories.contains(0.0)) + assert(bins(1)(2).lowSplit.categories.contains(1.0)) + assert(bins(1)(2).highSplit.categories.length === 3) + assert(bins(1)(2).highSplit.categories.contains(0.0)) + assert(bins(1)(2).highSplit.categories.contains(1.0)) + assert(bins(1)(2).highSplit.categories.contains(2.0)) + + assert(bins(1)(3) === null) + } + + test("classification stump with all categorical variables") { + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + Array[List[Filter]](), splits, bins) + + val split = bestSplits(0)._1 + assert(split.categories.length === 1) + assert(split.categories.contains(1.0)) + assert(split.featureType === Categorical) + assert(split.threshold === Double.MinValue) + + val stats = bestSplits(0)._2 + assert(stats.gain > 0) + assert(stats.predict > 0.4) + assert(stats.predict < 0.5) + assert(stats.impurity > 0.2) + } + + test("regression stump with all categorical variables") { + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Regression, + Variance, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) + val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + Array[List[Filter]](), splits, bins) + + val split = bestSplits(0)._1 + assert(split.categories.length === 1) + assert(split.categories.contains(1.0)) + assert(split.featureType === Categorical) + assert(split.threshold === Double.MinValue) + + val stats = bestSplits(0)._2 + assert(stats.gain > 0) + assert(stats.predict > 0.4) + assert(stats.predict < 0.5) + assert(stats.impurity > 0.2) + } + + test("stump with fixed label 0 for Gini") { + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Gini, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + Array[List[Filter]](), splits, bins) + assert(bestSplits.length === 1) + assert(bestSplits(0)._1.feature === 0) + assert(bestSplits(0)._1.threshold === 10) + assert(bestSplits(0)._2.gain === 0) + assert(bestSplits(0)._2.leftImpurity === 0) + assert(bestSplits(0)._2.rightImpurity === 0) + } + + test("stump with fixed label 1 for Gini") { + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Gini, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + Array[List[Filter]](), splits, bins) + assert(bestSplits.length === 1) + assert(bestSplits(0)._1.feature === 0) + assert(bestSplits(0)._1.threshold === 10) + assert(bestSplits(0)._2.gain === 0) + assert(bestSplits(0)._2.leftImpurity === 0) + assert(bestSplits(0)._2.rightImpurity === 0) + assert(bestSplits(0)._2.predict === 1) + } + + test("stump with fixed label 0 for Entropy") { + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Entropy, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + Array[List[Filter]](), splits, bins) + assert(bestSplits.length === 1) + assert(bestSplits(0)._1.feature === 0) + assert(bestSplits(0)._1.threshold === 10) + assert(bestSplits(0)._2.gain === 0) + assert(bestSplits(0)._2.leftImpurity === 0) + assert(bestSplits(0)._2.rightImpurity === 0) + assert(bestSplits(0)._2.predict === 0) + } + + test("stump with fixed label 1 for Entropy") { + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Entropy, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + Array[List[Filter]](), splits, bins) + assert(bestSplits.length === 1) + assert(bestSplits(0)._1.feature === 0) + assert(bestSplits(0)._1.threshold === 10) + assert(bestSplits(0)._2.gain === 0) + assert(bestSplits(0)._2.leftImpurity === 0) + assert(bestSplits(0)._2.rightImpurity === 0) + assert(bestSplits(0)._2.predict === 1) + } +} + +object DecisionTreeSuite { + + def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) + for (i <- 0 until 1000){ + val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) + arr(i) = lp + } + arr + } + + def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) + for (i <- 0 until 1000){ + val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i)) + arr(i) = lp + } + arr + } + + def generateCategoricalDataPoints(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) + for (i <- 0 until 1000){ + if (i < 600){ + arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)) + } else { + arr(i) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0)) + } + } + arr + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 60f053b381305..27d41c7869aa0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -17,14 +17,20 @@ package org.apache.spark.mllib.util +import java.io.File + import org.scalatest.FunSuite import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => breezeNorm, squaredDistance => breezeSquaredDistance} +import com.google.common.base.Charsets +import com.google.common.io.Files +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils._ -class MLUtilsSuite extends FunSuite { +class MLUtilsSuite extends FunSuite with LocalSparkContext { test("epsilon computation") { assert(1.0 + EPSILON > 1.0, s"EPSILON is too small: $EPSILON.") @@ -49,4 +55,55 @@ class MLUtilsSuite extends FunSuite { assert((fastSquaredDist2 - squaredDist) <= precision * squaredDist, s"failed with m = $m") } } + + test("compute stats") { + val data = Seq.fill(3)(Seq( + LabeledPoint(1.0, Vectors.dense(1.0, 2.0, 3.0)), + LabeledPoint(0.0, Vectors.dense(3.0, 4.0, 5.0)) + )).flatten + val rdd = sc.parallelize(data, 2) + val (meanLabel, mean, std) = MLUtils.computeStats(rdd, 3, 6) + assert(meanLabel === 0.5) + assert(mean === Vectors.dense(2.0, 3.0, 4.0)) + assert(std === Vectors.dense(1.0, 1.0, 1.0)) + } + + test("loadLibSVMData") { + val lines = + """ + |+1 1:1.0 3:2.0 5:3.0 + |-1 + |-1 2:4.0 4:5.0 6:6.0 + """.stripMargin + val tempDir = Files.createTempDir() + val file = new File(tempDir.getPath, "part-00000") + Files.write(lines, file, Charsets.US_ASCII) + val path = tempDir.toURI.toString + + val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, 6).collect() + val pointsWithoutNumFeatures = MLUtils.loadLibSVMData(sc, path).collect() + + for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) { + assert(points.length === 3) + assert(points(0).label === 1.0) + assert(points(0).features === Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) + assert(points(1).label == 0.0) + assert(points(1).features == Vectors.sparse(6, Seq())) + assert(points(2).label === 0.0) + assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0)))) + } + + val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MLUtils.multiclassLabelParser).collect() + assert(multiclassPoints.length === 3) + assert(multiclassPoints(0).label === 1.0) + assert(multiclassPoints(1).label === -1.0) + assert(multiclassPoints(2).label === -1.0) + + try { + file.delete() + tempDir.delete() + } catch { + case t: Throwable => + } + } } diff --git a/pom.xml b/pom.xml index 72acf2b402703..b91b14d2f84d0 100644 --- a/pom.xml +++ b/pom.xml @@ -110,7 +110,7 @@ 1.6 - 2.10.3 + 2.10.4 2.10 0.13.0 org.spark-project.akka @@ -192,33 +192,28 @@ org.eclipse.jetty jetty-util - 7.6.8.v20121106 + 8.1.14.v20131031 org.eclipse.jetty jetty-security - 7.6.8.v20121106 + 8.1.14.v20131031 org.eclipse.jetty jetty-plus - 7.6.8.v20121106 + 8.1.14.v20131031 org.eclipse.jetty jetty-server - 7.6.8.v20121106 + 8.1.14.v20131031 com.google.guava guava 14.0.1 - - com.google.code.findbugs - jsr305 - 1.3.9 - org.slf4j slf4j-api @@ -380,7 +375,7 @@ lift-json_${scala.binary.version} 2.5.1 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 9e269e6551341..a2a21d9763548 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -17,7 +17,7 @@ import sbt._ import sbt.Classpaths.publishTask -import Keys._ +import sbt.Keys._ import sbtassembly.Plugin._ import AssemblyKeys._ import scala.util.Properties @@ -27,7 +27,7 @@ import com.typesafe.tools.mima.plugin.MimaKeys.previousArtifact import scala.collection.JavaConversions._ // For Sonatype publishing -//import com.jsuereth.pgp.sbtplugin.PgpKeys._ +// import com.jsuereth.pgp.sbtplugin.PgpKeys._ object SparkBuild extends Build { val SPARK_VERSION = "1.0.0-SNAPSHOT" @@ -152,7 +152,7 @@ object SparkBuild extends Build { def sharedSettings = Defaults.defaultSettings ++ MimaBuild.mimaSettings(file(sparkHome)) ++ Seq( organization := "org.apache.spark", version := SPARK_VERSION, - scalaVersion := "2.10.3", + scalaVersion := "2.10.4", scalacOptions := Seq("-Xmax-classfile-name", "120", "-unchecked", "-deprecation", "-target:" + SCALAC_JVM_VERSION), javacOptions := Seq("-target", JAVAC_JVM_VERSION, "-source", JAVAC_JVM_VERSION), @@ -200,7 +200,7 @@ object SparkBuild extends Build { publishMavenStyle := true, - //useGpg in Global := true, + // useGpg in Global := true, pomExtra := ( @@ -248,13 +248,13 @@ object SparkBuild extends Build { */ libraryDependencies ++= Seq( - "io.netty" % "netty-all" % "4.0.17.Final", - "org.eclipse.jetty" % "jetty-server" % "7.6.8.v20121106", - "org.eclipse.jetty" % "jetty-util" % "7.6.8.v20121106", - "org.eclipse.jetty" % "jetty-plus" % "7.6.8.v20121106", - "org.eclipse.jetty" % "jetty-security" % "7.6.8.v20121106", + "io.netty" % "netty-all" % "4.0.17.Final", + "org.eclipse.jetty" % "jetty-server" % "8.1.14.v20131031", + "org.eclipse.jetty" % "jetty-util" % "8.1.14.v20131031", + "org.eclipse.jetty" % "jetty-plus" % "8.1.14.v20131031", + "org.eclipse.jetty" % "jetty-security" % "8.1.14.v20131031", /** Workaround for SPARK-959. Dependency used by org.eclipse.jetty. Fixed in ivy 2.3.0. */ - "org.eclipse.jetty.orbit" % "javax.servlet" % "2.5.0.v201103041518" artifacts Artifact("javax.servlet", "jar", "jar"), + "org.eclipse.jetty.orbit" % "javax.servlet" % "3.0.0.v201112011016" artifacts Artifact("javax.servlet", "jar", "jar"), "org.scalatest" %% "scalatest" % "1.9.1" % "test", "org.scalacheck" %% "scalacheck" % "1.10.0" % "test", "com.novocode" % "junit-interface" % "0.10" % "test", @@ -296,7 +296,6 @@ object SparkBuild extends Build { name := "spark-core", libraryDependencies ++= Seq( "com.google.guava" % "guava" % "14.0.1", - "com.google.code.findbugs" % "jsr305" % "1.3.9", "log4j" % "log4j" % "1.2.17", "org.slf4j" % "slf4j-api" % slf4jVersion, "org.slf4j" % "slf4j-log4j12" % slf4jVersion, diff --git a/project/plugins.sbt b/project/plugins.sbt index 4ff6f67af45c0..d787237ddc540 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,4 +1,4 @@ -scalaVersion := "2.10.3" +scalaVersion := "2.10.4" resolvers += Resolver.url("artifactory", url("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases"))(Resolver.ivyStylePatterns) @@ -22,3 +22,4 @@ addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.4.0") addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.0") + diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala new file mode 100644 index 0000000000000..0142256e90fb7 --- /dev/null +++ b/project/project/SparkPluginBuild.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import sbt._ +import sbt.Keys._ + +/** + * This plugin project is there to define new scala style rules for spark. This is + * a plugin project so that this gets compiled first and is put on the classpath and + * becomes available for scalastyle sbt plugin. + */ +object SparkPluginDef extends Build { + lazy val root = Project("plugins", file(".")) dependsOn(sparkStyle) + lazy val sparkStyle = Project("spark-style", file("spark-style"), settings = styleSettings) + val sparkVersion = "1.0.0-SNAPSHOT" + // There is actually no need to publish this artifact. + def styleSettings = Defaults.defaultSettings ++ Seq ( + name := "spark-style", + organization := "org.apache.spark", + version := sparkVersion, + scalaVersion := "2.10.4", + scalacOptions := Seq("-unchecked", "-deprecation"), + libraryDependencies ++= Dependencies.scalaStyle + ) + + object Dependencies { + val scalaStyle = Seq("org.scalastyle" %% "scalastyle" % "0.4.0") + } +} diff --git a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/SparkSpaceAfterCommentStartChecker.scala b/project/spark-style/src/main/scala/org/apache/spark/scalastyle/SparkSpaceAfterCommentStartChecker.scala new file mode 100644 index 0000000000000..80d3faa3fe749 --- /dev/null +++ b/project/spark-style/src/main/scala/org/apache/spark/scalastyle/SparkSpaceAfterCommentStartChecker.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.scalastyle + +import java.util.regex.Pattern + +import org.scalastyle.{PositionError, ScalariformChecker, ScalastyleError} +import scalariform.lexer.{MultiLineComment, ScalaDocComment, SingleLineComment, Token} +import scalariform.parser.CompilationUnit + +class SparkSpaceAfterCommentStartChecker extends ScalariformChecker { + val errorKey: String = "insert.a.single.space.after.comment.start.and.before.end" + + private def multiLineCommentRegex(comment: Token) = + Pattern.compile( """/\*\S+.*""", Pattern.DOTALL).matcher(comment.text.trim).matches() || + Pattern.compile( """/\*.*\S\*/""", Pattern.DOTALL).matcher(comment.text.trim).matches() + + private def scalaDocPatternRegex(comment: Token) = + Pattern.compile( """/\*\*\S+.*""", Pattern.DOTALL).matcher(comment.text.trim).matches() || + Pattern.compile( """/\*\*.*\S\*/""", Pattern.DOTALL).matcher(comment.text.trim).matches() + + private def singleLineCommentRegex(comment: Token): Boolean = + comment.text.trim.matches( """//\S+.*""") && !comment.text.trim.matches( """///+""") + + override def verify(ast: CompilationUnit): List[ScalastyleError] = { + ast.tokens + .filter(hasComment) + .map { + _.associatedWhitespaceAndComments.comments.map { + case x: SingleLineComment if singleLineCommentRegex(x.token) => Some(x.token.offset) + case x: MultiLineComment if multiLineCommentRegex(x.token) => Some(x.token.offset) + case x: ScalaDocComment if scalaDocPatternRegex(x.token) => Some(x.token.offset) + case _ => None + }.flatten + }.flatten.map(PositionError(_)) + } + + + private def hasComment(x: Token) = + x.associatedWhitespaceAndComments != null && !x.associatedWhitespaceAndComments.comments.isEmpty + +} diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 19b90dfd6e167..d2f9cdb3f4298 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -87,18 +87,19 @@ class NaiveBayesModel(object): >>> data = array([0.0, 0.0, 1.0, 0.0, 0.0, 2.0, 1.0, 1.0, 0.0]).reshape(3,3) >>> model = NaiveBayes.train(sc.parallelize(data)) >>> model.predict(array([0.0, 1.0])) - 0 + 0.0 >>> model.predict(array([1.0, 0.0])) - 1 + 1.0 """ - def __init__(self, pi, theta): + def __init__(self, labels, pi, theta): + self.labels = labels self.pi = pi self.theta = theta def predict(self, x): """Return the most likely class for a data vector x""" - return numpy.argmax(self.pi + dot(x, self.theta)) + return self.labels[numpy.argmax(self.pi + dot(x, self.theta))] class NaiveBayes(object): @classmethod @@ -122,7 +123,8 @@ def train(cls, data, lambda_=1.0): ans = sc._jvm.PythonMLLibAPI().trainNaiveBayes(dataBytes._jrdd, lambda_) return NaiveBayesModel( _deserialize_double_vector(ans[0]), - _deserialize_double_matrix(ans[1])) + _deserialize_double_vector(ans[1]), + _deserialize_double_matrix(ans[2])) def _test(): diff --git a/python/run-tests b/python/run-tests index a986ac9380be4..b2b60f08b48e2 100755 --- a/python/run-tests +++ b/python/run-tests @@ -29,8 +29,18 @@ FAILED=0 rm -f unit-tests.log function run_test() { - SPARK_TESTING=0 $FWDIR/bin/pyspark $1 2>&1 | tee -a unit-tests.log + SPARK_TESTING=0 $FWDIR/bin/pyspark $1 2>&1 | tee -a > unit-tests.log FAILED=$((PIPESTATUS[0]||$FAILED)) + + # Fail and exit on the first test failure. + if [[ $FAILED != 0 ]]; then + cat unit-tests.log | grep -v "^[0-9][0-9]*" # filter all lines starting with a number. + echo -en "\033[31m" # Red + echo "Had test failures; see logs." + echo -en "\033[0m" # No color + exit -1 + fi + } run_test "pyspark/rdd.py" @@ -46,12 +56,7 @@ run_test "pyspark/mllib/clustering.py" run_test "pyspark/mllib/recommendation.py" run_test "pyspark/mllib/regression.py" -if [[ $FAILED != 0 ]]; then - echo -en "\033[31m" # Red - echo "Had test failures; see logs." - echo -en "\033[0m" # No color - exit -1 -else +if [[ $FAILED == 0 ]]; then echo -en "\033[32m" # Green echo "Tests passed." echo -en "\033[0m" # No color diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index ee972887feda6..bf73800388ebf 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -124,8 +124,8 @@ extends ClassVisitor(ASM4, cv) { mv.visitVarInsn(ALOAD, 0) // load this mv.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "", "()V") mv.visitVarInsn(ALOAD, 0) // load this - //val classType = className.replace('.', '/') - //mv.visitFieldInsn(PUTSTATIC, classType, "MODULE$", "L" + classType + ";") + // val classType = className.replace('.', '/') + // mv.visitFieldInsn(PUTSTATIC, classType, "MODULE$", "L" + classType + ";") mv.visitInsn(RETURN) mv.visitMaxs(-1, -1) // stack size and local vars will be auto-computed mv.visitEnd() diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 90a96ad38381e..fa2f1a88c4eb5 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -834,7 +834,7 @@ import org.apache.spark.util.Utils } ((pos, msg)) :: loop(filtered) } - //PRASHANT: This leads to a NoSuchMethodError for _.warnings. Yet to figure out its purpose. + // PRASHANT: This leads to a NoSuchMethodError for _.warnings. Yet to figure out its purpose. // val warnings = loop(run.allConditionalWarnings flatMap (_.warnings)) // if (warnings.nonEmpty) // mostRecentWarnings = warnings diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 8203b8f6122e1..4155007c6d337 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -242,4 +242,15 @@ class ReplSuite extends FunSuite { assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output) } } + + test("collecting objects of class defined in repl") { + val output = runInterpreter("local[2]", + """ + |case class Foo(i: Int) + |val ret = sc.parallelize((1 to 100).map(Foo), 10).collect + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("ret: Array[Foo] = Array(Foo(1),", output) + } } diff --git a/scalastyle-config.xml b/scalastyle-config.xml index ee968c53b3e4b..76ba1ecca33ab 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -140,4 +140,5 @@ + diff --git a/sql/README.md b/sql/README.md index 4192fecb92fb0..14d5555f0c713 100644 --- a/sql/README.md +++ b/sql/README.md @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.TestHive._ -Welcome to Scala version 2.10.3 (Java HotSpot(TM) 64-Bit Server VM, Java 1.7.0_45). +Welcome to Scala version 2.10.4 (Java HotSpot(TM) 64-Bit Server VM, Java 1.7.0_45). Type in expressions to have them evaluated. Type :help for more information. diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 740f1fdc83299..0edce55a93338 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -31,6 +31,18 @@ Spark Project Catalyst http://spark.apache.org/ + + + yarn-alpha + + + org.apache.avro + avro + + + + + org.apache.spark diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 976dda8d7e59a..5aaa63bf3b4b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -43,15 +43,25 @@ object ScalaReflection { val params = t.member("": TermName).asMethod.paramss StructType( params.head.map(p => StructField(p.name.toString, schemaFor(p.typeSignature), true))) + // Need to decide if we actually need a special type here. + case t if t <:< typeOf[Array[Byte]] => BinaryType + case t if t <:< typeOf[Array[_]] => + sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") case t if t <:< typeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t ArrayType(schemaFor(elementType)) + case t if t <:< typeOf[Map[_,_]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + MapType(schemaFor(keyType), schemaFor(valueType)) case t if t <:< typeOf[String] => StringType case t if t <:< definitions.IntTpe => IntegerType case t if t <:< definitions.LongTpe => LongType + case t if t <:< definitions.FloatTpe => FloatType case t if t <:< definitions.DoubleTpe => DoubleType case t if t <:< definitions.ShortTpe => ShortType case t if t <:< definitions.ByteTpe => ByteType + case t if t <:< definitions.BooleanTpe => BooleanType + case t if t <:< typeOf[BigDecimal] => DecimalType } implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 9dec4e3d9e4c2..8de87594c8ab9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -114,6 +114,9 @@ class SqlParser extends StandardTokenParsers { protected val NULL = Keyword("NULL") protected val ON = Keyword("ON") protected val OR = Keyword("OR") + protected val LIKE = Keyword("LIKE") + protected val RLIKE = Keyword("RLIKE") + protected val REGEXP = Keyword("REGEXP") protected val ORDER = Keyword("ORDER") protected val OUTER = Keyword("OUTER") protected val RIGHT = Keyword("RIGHT") @@ -178,7 +181,7 @@ class SqlParser extends StandardTokenParsers { val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection) val withHaving = h.map(h => Filter(h, withDistinct)).getOrElse(withDistinct) val withOrder = o.map(o => Sort(o, withHaving)).getOrElse(withHaving) - val withLimit = l.map { l => StopAfter(l, withOrder) }.getOrElse(withOrder) + val withLimit = l.map { l => Limit(l, withOrder) }.getOrElse(withOrder) withLimit } @@ -267,6 +270,9 @@ class SqlParser extends StandardTokenParsers { termExpression ~ ">=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => GreaterThanOrEqual(e1, e2) } | termExpression ~ "!=" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(Equals(e1, e2)) } | termExpression ~ "<>" ~ termExpression ^^ { case e1 ~ _ ~ e2 => Not(Equals(e1, e2)) } | + termExpression ~ RLIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } | + termExpression ~ REGEXP ~ termExpression ^^ { case e1 ~ _ ~ e2 => RLike(e1, e2) } | + termExpression ~ LIKE ~ termExpression ^^ { case e1 ~ _ ~ e2 => Like(e1, e2) } | termExpression ~ IN ~ "(" ~ rep1sep(termExpression, ",") <~ ")" ^^ { case e1 ~ _ ~ _ ~ e2 => In(e1, e2) } | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index ff66177a03b8c..6b58b9322c4bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -31,6 +31,7 @@ trait Catalog { alias: Option[String] = None): LogicalPlan def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit + def unregisterTable(databaseName: Option[String], tableName: String): Unit } class SimpleCatalog extends Catalog { @@ -40,17 +41,18 @@ class SimpleCatalog extends Catalog { tables += ((tableName, plan)) } - def dropTable(tableName: String) = tables -= tableName + def unregisterTable(databaseName: Option[String], tableName: String) = { tables -= tableName } def lookupRelation( databaseName: Option[String], tableName: String, alias: Option[String] = None): LogicalPlan = { val table = tables.get(tableName).getOrElse(sys.error(s"Table Not Found: $tableName")) + val tableWithQualifiers = Subquery(tableName, table) // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are // properly qualified with this alias. - alias.map(a => Subquery(a.toLowerCase, table)).getOrElse(table) + alias.map(a => Subquery(a.toLowerCase, tableWithQualifiers)).getOrElse(tableWithQualifiers) } } @@ -86,6 +88,10 @@ trait OverrideCatalog extends Catalog { plan: LogicalPlan): Unit = { overrides.put((databaseName, tableName), plan) } + + override def unregisterTable(databaseName: Option[String], tableName: String): Unit = { + overrides.remove((databaseName, tableName)) + } } /** @@ -103,4 +109,8 @@ object EmptyCatalog extends Catalog { def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = { throw new UnsupportedOperationException } + + def unregisterTable(databaseName: Option[String], tableName: String): Unit = { + throw new UnsupportedOperationException + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 67cddb351c185..44abe671c07a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -70,6 +70,9 @@ package object dsl { def === (other: Expression) = Equals(expr, other) def != (other: Expression) = Not(Equals(expr, other)) + def like(other: Expression) = Like(expr, other) + def rlike(other: Expression) = RLike(expr, other) + def asc = SortOrder(expr, Ascending) def desc = SortOrder(expr, Descending) @@ -90,7 +93,10 @@ package object dsl { implicit def symbolToUnresolvedAttribute(s: Symbol) = analysis.UnresolvedAttribute(s.name) implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name } - implicit class DslString(val s: String) extends ImplicitAttribute + implicit class DslString(val s: String) extends ImplicitOperators { + def expr: Expression = Literal(s) + def attr = analysis.UnresolvedAttribute(s) + } abstract class ImplicitAttribute extends ImplicitOperators { def s: String @@ -110,6 +116,8 @@ package object dsl { // Protobuf terminology def required = a.withNullability(false) + + def at(ordinal: Int) = BoundReference(ordinal, a) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index 31d42b9ee71a0..6f939e6c41f6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -44,6 +44,16 @@ trait Row extends Seq[Any] with Serializable { s"[${this.mkString(",")}]" def copy(): Row + + /** Returns true if there are any NULL values in this row. */ + def anyNull: Boolean = { + var i = 0 + while (i < length) { + if (isNullAt(i)) { return true } + i += 1 + } + false + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 722ff517d250e..02fedd16b8d4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -21,6 +21,12 @@ import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.types.{BooleanType, StringType} +object InterpretedPredicate { + def apply(expression: Expression): (Row => Boolean) = { + (r: Row) => expression.apply(r).asInstanceOf[Boolean] + } +} + trait Predicate extends Expression { self: Product => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index e195f2ac7efd1..42b7a9b125b7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -17,11 +17,103 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.regex.Pattern + +import org.apache.spark.sql.catalyst.types.DataType +import org.apache.spark.sql.catalyst.types.StringType import org.apache.spark.sql.catalyst.types.BooleanType +import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.errors.`package`.TreeNodeException + + +trait StringRegexExpression { + self: BinaryExpression => + + type EvaluatedType = Any + + def escape(v: String): String + def matches(regex: Pattern, str: String): Boolean + + def nullable: Boolean = true + def dataType: DataType = BooleanType + + // try cache the pattern for Literal + private lazy val cache: Pattern = right match { + case x @ Literal(value: String, StringType) => compile(value) + case _ => null + } + + protected def compile(str: String): Pattern = if(str == null) { + null + } else { + // Let it raise exception if couldn't compile the regex string + Pattern.compile(escape(str)) + } -case class Like(left: Expression, right: Expression) extends BinaryExpression { - def dataType = BooleanType - def nullable = left.nullable // Right cannot be null. + protected def pattern(str: String) = if(cache == null) compile(str) else cache + + override def apply(input: Row): Any = { + val l = left.apply(input) + if(l == null) { + null + } else { + val r = right.apply(input) + if(r == null) { + null + } else { + val regex = pattern(r.asInstanceOf[String]) + if(regex == null) { + null + } else { + matches(regex, l.asInstanceOf[String]) + } + } + } + } +} + +/** + * Simple RegEx pattern matching function + */ +case class Like(left: Expression, right: Expression) + extends BinaryExpression with StringRegexExpression { + def symbol = "LIKE" + + // replace the _ with .{1} exactly match 1 time of any character + // replace the % with .*, match 0 or more times with any character + override def escape(v: String) = { + val sb = new StringBuilder() + var i = 0; + while (i < v.length) { + // Make a special case for "\\_" and "\\%" + val n = v.charAt(i); + if (n == '\\' && i + 1 < v.length && (v.charAt(i + 1) == '_' || v.charAt(i + 1) == '%')) { + sb.append(v.charAt(i + 1)) + i += 1 + } else { + if (n == '_') { + sb.append("."); + } else if (n == '%') { + sb.append(".*"); + } else { + sb.append(Pattern.quote(Character.toString(n))); + } + } + + i += 1 + } + + sb.toString() + } + + override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() } +case class RLike(left: Expression, right: Expression) + extends BinaryExpression with StringRegexExpression { + + def symbol = "RLIKE" + override def escape(v: String): String = v + override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 9d16189deedfe..b39c2b32cc42c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -130,7 +130,7 @@ case class Aggregate( def references = child.references } -case class StopAfter(limit: Expression, child: LogicalPlan) extends UnaryNode { +case class Limit(limit: Expression, child: LogicalPlan) extends UnaryNode { def output = child.output def references = limit.references } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 94894adf81202..52a205be3e9f4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -109,4 +109,87 @@ class ExpressionEvaluationSuite extends FunSuite { } } } + + def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { + expression.apply(inputRow) + } + + def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expression", e) + } + if(actual != expected) { + val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + + test("LIKE literal Regular Expression") { + checkEvaluation(Literal(null, StringType).like("a"), null) + checkEvaluation(Literal(null, StringType).like(Literal(null, StringType)), null) + checkEvaluation("abdef" like "abdef", true) + checkEvaluation("a_%b" like "a\\__b", true) + checkEvaluation("addb" like "a_%b", true) + checkEvaluation("addb" like "a\\__b", false) + checkEvaluation("addb" like "a%\\%b", false) + checkEvaluation("a_%b" like "a%\\%b", true) + checkEvaluation("addb" like "a%", true) + checkEvaluation("addb" like "**", false) + checkEvaluation("abc" like "a%", true) + checkEvaluation("abc" like "b%", false) + checkEvaluation("abc" like "bc%", false) + } + + test("LIKE Non-literal Regular Expression") { + val regEx = 'a.string.at(0) + checkEvaluation("abcd" like regEx, null, new GenericRow(Array[Any](null))) + checkEvaluation("abdef" like regEx, true, new GenericRow(Array[Any]("abdef"))) + checkEvaluation("a_%b" like regEx, true, new GenericRow(Array[Any]("a\\__b"))) + checkEvaluation("addb" like regEx, true, new GenericRow(Array[Any]("a_%b"))) + checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("a\\__b"))) + checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("a%\\%b"))) + checkEvaluation("a_%b" like regEx, true, new GenericRow(Array[Any]("a%\\%b"))) + checkEvaluation("addb" like regEx, true, new GenericRow(Array[Any]("a%"))) + checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("**"))) + checkEvaluation("abc" like regEx, true, new GenericRow(Array[Any]("a%"))) + checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("b%"))) + checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("bc%"))) + } + + test("RLIKE literal Regular Expression") { + checkEvaluation("abdef" rlike "abdef", true) + checkEvaluation("abbbbc" rlike "a.*c", true) + + checkEvaluation("fofo" rlike "^fo", true) + checkEvaluation("fo\no" rlike "^fo\no$", true) + checkEvaluation("Bn" rlike "^Ba*n", true) + checkEvaluation("afofo" rlike "fo", true) + checkEvaluation("afofo" rlike "^fo", false) + checkEvaluation("Baan" rlike "^Ba?n", false) + checkEvaluation("axe" rlike "pi|apa", false) + checkEvaluation("pip" rlike "^(pi)*$", false) + + checkEvaluation("abc" rlike "^ab", true) + checkEvaluation("abc" rlike "^bc", false) + checkEvaluation("abc" rlike "^ab", true) + checkEvaluation("abc" rlike "^bc", false) + + intercept[java.util.regex.PatternSyntaxException] { + evaluate("abbbbc" rlike "**") + } + } + + test("RLIKE Non-literal Regular Expression") { + val regEx = 'a.string.at(0) + checkEvaluation("abdef" rlike regEx, true, new GenericRow(Array[Any]("abdef"))) + checkEvaluation("abbbbc" rlike regEx, true, new GenericRow(Array[Any]("a.*c"))) + checkEvaluation("fofo" rlike regEx, true, new GenericRow(Array[Any]("^fo"))) + checkEvaluation("fo\no" rlike regEx, true, new GenericRow(Array[Any]("^fo\no$"))) + checkEvaluation("Bn" rlike regEx, true, new GenericRow(Array[Any]("^Ba*n"))) + + intercept[java.util.regex.PatternSyntaxException] { + evaluate("abbbbc" rlike regEx, new GenericRow(Array[Any]("**"))) + } + } } + diff --git a/sql/core/pom.xml b/sql/core/pom.xml index e367edfb1f562..85580ed6b822f 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -30,6 +30,17 @@ jar Spark Project SQL http://spark.apache.org/ + + + yarn-alpha + + + org.apache.avro + avro + + + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index cf3c06acce5b0..f4bf00f4cffa6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -26,8 +26,9 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.dsl import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.Optimizer -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Subquery, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.columnar.InMemoryColumnarTableScan import org.apache.spark.sql.execution._ /** @@ -111,13 +112,42 @@ class SQLContext(@transient val sparkContext: SparkContext) result } + /** Returns the specified table as a SchemaRDD */ + def table(tableName: String): SchemaRDD = + new SchemaRDD(this, catalog.lookupRelation(None, tableName)) + + /** Caches the specified table in-memory. */ + def cacheTable(tableName: String): Unit = { + val currentTable = catalog.lookupRelation(None, tableName) + val asInMemoryRelation = + InMemoryColumnarTableScan(currentTable.output, executePlan(currentTable).executedPlan) + + catalog.registerTable(None, tableName, SparkLogicalPlan(asInMemoryRelation)) + } + + /** Removes the specified table from the in-memory cache. */ + def uncacheTable(tableName: String): Unit = { + EliminateAnalysisOperators(catalog.lookupRelation(None, tableName)) match { + // This is kind of a hack to make sure that if this was just an RDD registered as a table, + // we reregister the RDD as a table. + case SparkLogicalPlan(inMem @ InMemoryColumnarTableScan(_, e: ExistingRdd)) => + inMem.cachedColumnBuffers.unpersist() + catalog.unregisterTable(None, tableName) + catalog.registerTable(None, tableName, SparkLogicalPlan(e)) + case SparkLogicalPlan(inMem: InMemoryColumnarTableScan) => + inMem.cachedColumnBuffers.unpersist() + catalog.unregisterTable(None, tableName) + case plan => throw new IllegalArgumentException(s"Table $tableName is not cached: $plan") + } + } + protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext = self.sparkContext val strategies: Seq[Strategy] = - TopK :: + TakeOrdered :: PartialAggregation :: - SparkEquiInnerJoin :: + HashJoin :: ParquetOperations :: BasicOperators :: CartesianProduct :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index e0c98ecdf8f22..ffd4894b5213d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -21,7 +21,7 @@ import java.nio.{ByteOrder, ByteBuffer} import org.apache.spark.sql.catalyst.types.{BinaryType, NativeType, DataType} import org.apache.spark.sql.catalyst.expressions.MutableRow -import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor /** * An `Iterator` like trait used to extract values from columnar byte buffer. When a value is @@ -41,121 +41,66 @@ private[sql] trait ColumnAccessor { protected def underlyingBuffer: ByteBuffer } -private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType](buffer: ByteBuffer) +private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType]( + protected val buffer: ByteBuffer, + protected val columnType: ColumnType[T, JvmType]) extends ColumnAccessor { protected def initialize() {} - def columnType: ColumnType[T, JvmType] - def hasNext = buffer.hasRemaining def extractTo(row: MutableRow, ordinal: Int) { - doExtractTo(row, ordinal) + columnType.setField(row, ordinal, extractSingle(buffer)) } - protected def doExtractTo(row: MutableRow, ordinal: Int) + def extractSingle(buffer: ByteBuffer): JvmType = columnType.extract(buffer) protected def underlyingBuffer = buffer } private[sql] abstract class NativeColumnAccessor[T <: NativeType]( - buffer: ByteBuffer, - val columnType: NativeColumnType[T]) - extends BasicColumnAccessor[T, T#JvmType](buffer) + override protected val buffer: ByteBuffer, + override protected val columnType: NativeColumnType[T]) + extends BasicColumnAccessor(buffer, columnType) with NullableColumnAccessor + with CompressibleColumnAccessor[T] private[sql] class BooleanColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, BOOLEAN) { - - override protected def doExtractTo(row: MutableRow, ordinal: Int) { - row.setBoolean(ordinal, columnType.extract(buffer)) - } -} + extends NativeColumnAccessor(buffer, BOOLEAN) private[sql] class IntColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, INT) { - - override protected def doExtractTo(row: MutableRow, ordinal: Int) { - row.setInt(ordinal, columnType.extract(buffer)) - } -} + extends NativeColumnAccessor(buffer, INT) private[sql] class ShortColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, SHORT) { - - override protected def doExtractTo(row: MutableRow, ordinal: Int) { - row.setShort(ordinal, columnType.extract(buffer)) - } -} + extends NativeColumnAccessor(buffer, SHORT) private[sql] class LongColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, LONG) { - - override protected def doExtractTo(row: MutableRow, ordinal: Int) { - row.setLong(ordinal, columnType.extract(buffer)) - } -} + extends NativeColumnAccessor(buffer, LONG) private[sql] class ByteColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, BYTE) { - - override protected def doExtractTo(row: MutableRow, ordinal: Int) { - row.setByte(ordinal, columnType.extract(buffer)) - } -} + extends NativeColumnAccessor(buffer, BYTE) private[sql] class DoubleColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, DOUBLE) { - - override protected def doExtractTo(row: MutableRow, ordinal: Int) { - row.setDouble(ordinal, columnType.extract(buffer)) - } -} + extends NativeColumnAccessor(buffer, DOUBLE) private[sql] class FloatColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, FLOAT) { - - override protected def doExtractTo(row: MutableRow, ordinal: Int) { - row.setFloat(ordinal, columnType.extract(buffer)) - } -} + extends NativeColumnAccessor(buffer, FLOAT) private[sql] class StringColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, STRING) { - - override protected def doExtractTo(row: MutableRow, ordinal: Int) { - row.setString(ordinal, columnType.extract(buffer)) - } -} + extends NativeColumnAccessor(buffer, STRING) private[sql] class BinaryColumnAccessor(buffer: ByteBuffer) - extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer) - with NullableColumnAccessor { - - def columnType = BINARY - - override protected def doExtractTo(row: MutableRow, ordinal: Int) { - row(ordinal) = columnType.extract(buffer) - } -} + extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer, BINARY) + with NullableColumnAccessor private[sql] class GenericColumnAccessor(buffer: ByteBuffer) - extends BasicColumnAccessor[DataType, Array[Byte]](buffer) - with NullableColumnAccessor { - - def columnType = GENERIC - - override protected def doExtractTo(row: MutableRow, ordinal: Int) { - val serialized = columnType.extract(buffer) - row(ordinal) = SparkSqlSerializer.deserialize[Any](serialized) - } -} + extends BasicColumnAccessor[DataType, Array[Byte]](buffer, GENERIC) + with NullableColumnAccessor private[sql] object ColumnAccessor { - def apply(b: ByteBuffer): ColumnAccessor = { - // The first 4 bytes in the buffer indicates the column type. - val buffer = b.duplicate().order(ByteOrder.nativeOrder()) + def apply(buffer: ByteBuffer): ColumnAccessor = { + // The first 4 bytes in the buffer indicate the column type. val columnTypeId = buffer.getInt() columnTypeId match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 3e622adfd3d6a..048ee66bff44b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -22,7 +22,7 @@ import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.ColumnBuilder._ -import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.columnar.compression.{AllCompressionSchemes, CompressibleColumnBuilder} private[sql] trait ColumnBuilder { /** @@ -30,37 +30,44 @@ private[sql] trait ColumnBuilder { */ def initialize(initialSize: Int, columnName: String = "") + /** + * Appends `row(ordinal)` to the column builder. + */ def appendFrom(row: Row, ordinal: Int) + /** + * Column statistics information + */ + def columnStats: ColumnStats[_, _] + + /** + * Returns the final columnar byte buffer. + */ def build(): ByteBuffer } -private[sql] abstract class BasicColumnBuilder[T <: DataType, JvmType] extends ColumnBuilder { +private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( + val columnStats: ColumnStats[T, JvmType], + val columnType: ColumnType[T, JvmType]) + extends ColumnBuilder { - private var columnName: String = _ - protected var buffer: ByteBuffer = _ + protected var columnName: String = _ - def columnType: ColumnType[T, JvmType] + protected var buffer: ByteBuffer = _ override def initialize(initialSize: Int, columnName: String = "") = { val size = if (initialSize == 0) DEFAULT_INITIAL_BUFFER_SIZE else initialSize this.columnName = columnName - buffer = ByteBuffer.allocate(4 + 4 + size * columnType.defaultSize) + + // Reserves 4 bytes for column type ID + buffer = ByteBuffer.allocate(4 + size * columnType.defaultSize) buffer.order(ByteOrder.nativeOrder()).putInt(columnType.typeId) } - // Have to give a concrete implementation to make mixin possible override def appendFrom(row: Row, ordinal: Int) { - doAppendFrom(row, ordinal) - } - - // Concrete `ColumnBuilder`s can override this method to append values - protected def doAppendFrom(row: Row, ordinal: Int) - - // Helper method to append primitive values (to avoid boxing cost) - protected def appendValue(v: JvmType) { - buffer = ensureFreeSpace(buffer, columnType.actualSize(v)) - columnType.append(v, buffer) + val field = columnType.getField(row, ordinal) + buffer = ensureFreeSpace(buffer, columnType.actualSize(field)) + columnType.append(field, buffer) } override def build() = { @@ -69,83 +76,39 @@ private[sql] abstract class BasicColumnBuilder[T <: DataType, JvmType] extends C } } -private[sql] abstract class NativeColumnBuilder[T <: NativeType]( - val columnType: NativeColumnType[T]) - extends BasicColumnBuilder[T, T#JvmType] +private[sql] abstract class ComplexColumnBuilder[T <: DataType, JvmType]( + columnType: ColumnType[T, JvmType]) + extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType) with NullableColumnBuilder -private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(BOOLEAN) { - override def doAppendFrom(row: Row, ordinal: Int) { - appendValue(row.getBoolean(ordinal)) - } -} - -private[sql] class IntColumnBuilder extends NativeColumnBuilder(INT) { - override def doAppendFrom(row: Row, ordinal: Int) { - appendValue(row.getInt(ordinal)) - } -} +private[sql] abstract class NativeColumnBuilder[T <: NativeType]( + override val columnStats: NativeColumnStats[T], + override val columnType: NativeColumnType[T]) + extends BasicColumnBuilder[T, T#JvmType](columnStats, columnType) + with NullableColumnBuilder + with AllCompressionSchemes + with CompressibleColumnBuilder[T] -private[sql] class ShortColumnBuilder extends NativeColumnBuilder(SHORT) { - override def doAppendFrom(row: Row, ordinal: Int) { - appendValue(row.getShort(ordinal)) - } -} +private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) -private[sql] class LongColumnBuilder extends NativeColumnBuilder(LONG) { - override def doAppendFrom(row: Row, ordinal: Int) { - appendValue(row.getLong(ordinal)) - } -} +private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) -private[sql] class ByteColumnBuilder extends NativeColumnBuilder(BYTE) { - override def doAppendFrom(row: Row, ordinal: Int) { - appendValue(row.getByte(ordinal)) - } -} +private[sql] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT) -private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(DOUBLE) { - override def doAppendFrom(row: Row, ordinal: Int) { - appendValue(row.getDouble(ordinal)) - } -} +private[sql] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG) -private[sql] class FloatColumnBuilder extends NativeColumnBuilder(FLOAT) { - override def doAppendFrom(row: Row, ordinal: Int) { - appendValue(row.getFloat(ordinal)) - } -} +private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE) -private[sql] class StringColumnBuilder extends NativeColumnBuilder(STRING) { - override def doAppendFrom(row: Row, ordinal: Int) { - appendValue(row.getString(ordinal)) - } -} +private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE) -private[sql] class BinaryColumnBuilder - extends BasicColumnBuilder[BinaryType.type, Array[Byte]] - with NullableColumnBuilder { +private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) - def columnType = BINARY +private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) - override def doAppendFrom(row: Row, ordinal: Int) { - appendValue(row(ordinal).asInstanceOf[Array[Byte]]) - } -} +private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(BINARY) // TODO (lian) Add support for array, struct and map -private[sql] class GenericColumnBuilder - extends BasicColumnBuilder[DataType, Array[Byte]] - with NullableColumnBuilder { - - def columnType = GENERIC - - override def doAppendFrom(row: Row, ordinal: Int) { - val serialized = SparkSqlSerializer.serialize(row(ordinal)) - buffer = ColumnBuilder.ensureFreeSpace(buffer, columnType.actualSize(serialized)) - columnType.append(serialized, buffer) - } -} +private[sql] class GenericColumnBuilder extends ComplexColumnBuilder(GENERIC) private[sql] object ColumnBuilder { val DEFAULT_INITIAL_BUFFER_SIZE = 10 * 1024 * 104 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala new file mode 100644 index 0000000000000..30c6bdc7912fc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -0,0 +1,360 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.types._ + +private[sql] sealed abstract class ColumnStats[T <: DataType, JvmType] extends Serializable { + /** + * Closed lower bound of this column. + */ + def lowerBound: JvmType + + /** + * Closed upper bound of this column. + */ + def upperBound: JvmType + + /** + * Gathers statistics information from `row(ordinal)`. + */ + def gatherStats(row: Row, ordinal: Int) + + /** + * Returns `true` if `lower <= row(ordinal) <= upper`. + */ + def contains(row: Row, ordinal: Int): Boolean + + /** + * Returns `true` if `row(ordinal) < upper` holds. + */ + def isAbove(row: Row, ordinal: Int): Boolean + + /** + * Returns `true` if `lower < row(ordinal)` holds. + */ + def isBelow(row: Row, ordinal: Int): Boolean + + /** + * Returns `true` if `row(ordinal) <= upper` holds. + */ + def isAtOrAbove(row: Row, ordinal: Int): Boolean + + /** + * Returns `true` if `lower <= row(ordinal)` holds. + */ + def isAtOrBelow(row: Row, ordinal: Int): Boolean +} + +private[sql] sealed abstract class NativeColumnStats[T <: NativeType] + extends ColumnStats[T, T#JvmType] { + + type JvmType = T#JvmType + + protected var (_lower, _upper) = initialBounds + + def initialBounds: (JvmType, JvmType) + + protected def columnType: NativeColumnType[T] + + override def lowerBound: T#JvmType = _lower + + override def upperBound: T#JvmType = _upper + + override def isAtOrAbove(row: Row, ordinal: Int) = { + contains(row, ordinal) || isAbove(row, ordinal) + } + + override def isAtOrBelow(row: Row, ordinal: Int) = { + contains(row, ordinal) || isBelow(row, ordinal) + } +} + +private[sql] class NoopColumnStats[T <: DataType, JvmType] extends ColumnStats[T, JvmType] { + override def isAtOrBelow(row: Row, ordinal: Int) = true + + override def isAtOrAbove(row: Row, ordinal: Int) = true + + override def isBelow(row: Row, ordinal: Int) = true + + override def isAbove(row: Row, ordinal: Int) = true + + override def contains(row: Row, ordinal: Int) = true + + override def gatherStats(row: Row, ordinal: Int) {} + + override def upperBound = null.asInstanceOf[JvmType] + + override def lowerBound = null.asInstanceOf[JvmType] +} + +private[sql] abstract class BasicColumnStats[T <: NativeType]( + protected val columnType: NativeColumnType[T]) + extends NativeColumnStats[T] + +private[sql] class BooleanColumnStats extends BasicColumnStats(BOOLEAN) { + override def initialBounds = (true, false) + + override def isBelow(row: Row, ordinal: Int) = { + lowerBound < columnType.getField(row, ordinal) + } + + override def isAbove(row: Row, ordinal: Int) = { + columnType.getField(row, ordinal) < upperBound + } + + override def contains(row: Row, ordinal: Int) = { + val field = columnType.getField(row, ordinal) + lowerBound <= field && field <= upperBound + } + + override def gatherStats(row: Row, ordinal: Int) { + val field = columnType.getField(row, ordinal) + if (field > upperBound) _upper = field + if (field < lowerBound) _lower = field + } +} + +private[sql] class ByteColumnStats extends BasicColumnStats(BYTE) { + override def initialBounds = (Byte.MaxValue, Byte.MinValue) + + override def isBelow(row: Row, ordinal: Int) = { + lowerBound < columnType.getField(row, ordinal) + } + + override def isAbove(row: Row, ordinal: Int) = { + columnType.getField(row, ordinal) < upperBound + } + + override def contains(row: Row, ordinal: Int) = { + val field = columnType.getField(row, ordinal) + lowerBound <= field && field <= upperBound + } + + override def gatherStats(row: Row, ordinal: Int) { + val field = columnType.getField(row, ordinal) + if (field > upperBound) _upper = field + if (field < lowerBound) _lower = field + } +} + +private[sql] class ShortColumnStats extends BasicColumnStats(SHORT) { + override def initialBounds = (Short.MaxValue, Short.MinValue) + + override def isBelow(row: Row, ordinal: Int) = { + lowerBound < columnType.getField(row, ordinal) + } + + override def isAbove(row: Row, ordinal: Int) = { + columnType.getField(row, ordinal) < upperBound + } + + override def contains(row: Row, ordinal: Int) = { + val field = columnType.getField(row, ordinal) + lowerBound <= field && field <= upperBound + } + + override def gatherStats(row: Row, ordinal: Int) { + val field = columnType.getField(row, ordinal) + if (field > upperBound) _upper = field + if (field < lowerBound) _lower = field + } +} + +private[sql] class LongColumnStats extends BasicColumnStats(LONG) { + override def initialBounds = (Long.MaxValue, Long.MinValue) + + override def isBelow(row: Row, ordinal: Int) = { + lowerBound < columnType.getField(row, ordinal) + } + + override def isAbove(row: Row, ordinal: Int) = { + columnType.getField(row, ordinal) < upperBound + } + + override def contains(row: Row, ordinal: Int) = { + val field = columnType.getField(row, ordinal) + lowerBound <= field && field <= upperBound + } + + override def gatherStats(row: Row, ordinal: Int) { + val field = columnType.getField(row, ordinal) + if (field > upperBound) _upper = field + if (field < lowerBound) _lower = field + } +} + +private[sql] class DoubleColumnStats extends BasicColumnStats(DOUBLE) { + override def initialBounds = (Double.MaxValue, Double.MinValue) + + override def isBelow(row: Row, ordinal: Int) = { + lowerBound < columnType.getField(row, ordinal) + } + + override def isAbove(row: Row, ordinal: Int) = { + columnType.getField(row, ordinal) < upperBound + } + + override def contains(row: Row, ordinal: Int) = { + val field = columnType.getField(row, ordinal) + lowerBound <= field && field <= upperBound + } + + override def gatherStats(row: Row, ordinal: Int) { + val field = columnType.getField(row, ordinal) + if (field > upperBound) _upper = field + if (field < lowerBound) _lower = field + } +} + +private[sql] class FloatColumnStats extends BasicColumnStats(FLOAT) { + override def initialBounds = (Float.MaxValue, Float.MinValue) + + override def isBelow(row: Row, ordinal: Int) = { + lowerBound < columnType.getField(row, ordinal) + } + + override def isAbove(row: Row, ordinal: Int) = { + columnType.getField(row, ordinal) < upperBound + } + + override def contains(row: Row, ordinal: Int) = { + val field = columnType.getField(row, ordinal) + lowerBound <= field && field <= upperBound + } + + override def gatherStats(row: Row, ordinal: Int) { + val field = columnType.getField(row, ordinal) + if (field > upperBound) _upper = field + if (field < lowerBound) _lower = field + } +} + +private[sql] object IntColumnStats { + val UNINITIALIZED = 0 + val INITIALIZED = 1 + val ASCENDING = 2 + val DESCENDING = 3 + val UNORDERED = 4 +} + +/** + * Statistical information for `Int` columns. More information is collected since `Int` is + * frequently used. Extra information include: + * + * - Ordering state (ascending/descending/unordered), may be used to decide whether binary search + * is applicable when searching elements. + * - Maximum delta between adjacent elements, may be used to guide the `IntDelta` compression + * scheme. + * + * (This two kinds of information are not used anywhere yet and might be removed later.) + */ +private[sql] class IntColumnStats extends BasicColumnStats(INT) { + import IntColumnStats._ + + private var orderedState = UNINITIALIZED + private var lastValue: Int = _ + private var _maxDelta: Int = _ + + def isAscending = orderedState != DESCENDING && orderedState != UNORDERED + def isDescending = orderedState != ASCENDING && orderedState != UNORDERED + def isOrdered = isAscending || isDescending + def maxDelta = _maxDelta + + override def initialBounds = (Int.MaxValue, Int.MinValue) + + override def isBelow(row: Row, ordinal: Int) = { + lowerBound < columnType.getField(row, ordinal) + } + + override def isAbove(row: Row, ordinal: Int) = { + columnType.getField(row, ordinal) < upperBound + } + + override def contains(row: Row, ordinal: Int) = { + val field = columnType.getField(row, ordinal) + lowerBound <= field && field <= upperBound + } + + override def gatherStats(row: Row, ordinal: Int) { + val field = columnType.getField(row, ordinal) + + if (field > upperBound) _upper = field + if (field < lowerBound) _lower = field + + orderedState = orderedState match { + case UNINITIALIZED => + lastValue = field + INITIALIZED + + case INITIALIZED => + // If all the integers in the column are the same, ordered state is set to Ascending. + // TODO (lian) Confirm whether this is the standard behaviour. + val nextState = if (field >= lastValue) ASCENDING else DESCENDING + _maxDelta = math.abs(field - lastValue) + lastValue = field + nextState + + case ASCENDING if field < lastValue => + UNORDERED + + case DESCENDING if field > lastValue => + UNORDERED + + case state @ (ASCENDING | DESCENDING) => + _maxDelta = _maxDelta.max(field - lastValue) + lastValue = field + state + + case _ => + orderedState + } + } +} + +private[sql] class StringColumnStats extends BasicColumnStats(STRING) { + override def initialBounds = (null, null) + + override def gatherStats(row: Row, ordinal: Int) { + val field = columnType.getField(row, ordinal) + if ((upperBound eq null) || field.compareTo(upperBound) > 0) _upper = field + if ((lowerBound eq null) || field.compareTo(lowerBound) < 0) _lower = field + } + + override def contains(row: Row, ordinal: Int) = { + !(upperBound eq null) && { + val field = columnType.getField(row, ordinal) + lowerBound.compareTo(field) <= 0 && field.compareTo(upperBound) <= 0 + } + } + + override def isAbove(row: Row, ordinal: Int) = { + !(upperBound eq null) && { + val field = columnType.getField(row, ordinal) + field.compareTo(upperBound) < 0 + } + } + + override def isBelow(row: Row, ordinal: Int) = { + !(lowerBound eq null) && { + val field = columnType.getField(row, ordinal) + lowerBound.compareTo(field) < 0 + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index a452b86f0cda3..5be76890afe31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -19,7 +19,12 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.execution.SparkSqlSerializer /** * An abstract class that represents type of a column. Used to append/extract Java objects into/from @@ -50,10 +55,24 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( */ def actualSize(v: JvmType): Int = defaultSize + /** + * Returns `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs + * whenever possible. + */ + def getField(row: Row, ordinal: Int): JvmType + + /** + * Sets `row(ordinal)` to `field`. Subclasses should override this method to avoid boxing/unboxing + * costs whenever possible. + */ + def setField(row: MutableRow, ordinal: Int, value: JvmType) + /** * Creates a duplicated copy of the value. */ def clone(v: JvmType): JvmType = v + + override def toString = getClass.getSimpleName.stripSuffix("$") } private[sql] abstract class NativeColumnType[T <: NativeType]( @@ -65,7 +84,7 @@ private[sql] abstract class NativeColumnType[T <: NativeType]( /** * Scala TypeTag. Can be used to create primitive arrays and hash tables. */ - def scalaTag = dataType.tag + def scalaTag: TypeTag[dataType.JvmType] = dataType.tag } private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { @@ -76,6 +95,12 @@ private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { def extract(buffer: ByteBuffer) = { buffer.getInt() } + + override def setField(row: MutableRow, ordinal: Int, value: Int) { + row.setInt(ordinal, value) + } + + override def getField(row: Row, ordinal: Int) = row.getInt(ordinal) } private[sql] object LONG extends NativeColumnType(LongType, 1, 8) { @@ -86,6 +111,12 @@ private[sql] object LONG extends NativeColumnType(LongType, 1, 8) { override def extract(buffer: ByteBuffer) = { buffer.getLong() } + + override def setField(row: MutableRow, ordinal: Int, value: Long) { + row.setLong(ordinal, value) + } + + override def getField(row: Row, ordinal: Int) = row.getLong(ordinal) } private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) { @@ -96,6 +127,12 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) { override def extract(buffer: ByteBuffer) = { buffer.getFloat() } + + override def setField(row: MutableRow, ordinal: Int, value: Float) { + row.setFloat(ordinal, value) + } + + override def getField(row: Row, ordinal: Int) = row.getFloat(ordinal) } private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) { @@ -106,6 +143,12 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) { override def extract(buffer: ByteBuffer) = { buffer.getDouble() } + + override def setField(row: MutableRow, ordinal: Int, value: Double) { + row.setDouble(ordinal, value) + } + + override def getField(row: Row, ordinal: Int) = row.getDouble(ordinal) } private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) { @@ -116,6 +159,12 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) { override def extract(buffer: ByteBuffer) = { if (buffer.get() == 1) true else false } + + override def setField(row: MutableRow, ordinal: Int, value: Boolean) { + row.setBoolean(ordinal, value) + } + + override def getField(row: Row, ordinal: Int) = row.getBoolean(ordinal) } private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) { @@ -126,6 +175,12 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) { override def extract(buffer: ByteBuffer) = { buffer.get() } + + override def setField(row: MutableRow, ordinal: Int, value: Byte) { + row.setByte(ordinal, value) + } + + override def getField(row: Row, ordinal: Int) = row.getByte(ordinal) } private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { @@ -136,6 +191,12 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { override def extract(buffer: ByteBuffer) = { buffer.getShort() } + + override def setField(row: MutableRow, ordinal: Int, value: Short) { + row.setShort(ordinal, value) + } + + override def getField(row: Row, ordinal: Int) = row.getShort(ordinal) } private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { @@ -152,6 +213,12 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { buffer.get(stringBytes, 0, length) new String(stringBytes) } + + override def setField(row: MutableRow, ordinal: Int, value: String) { + row.setString(ordinal, value) + } + + override def getField(row: Row, ordinal: Int) = row.getString(ordinal) } private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( @@ -173,15 +240,27 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( } } -private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](8, 16) +private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](8, 16) { + override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) { + row(ordinal) = value + } + + override def getField(row: Row, ordinal: Int) = row(ordinal).asInstanceOf[Array[Byte]] +} // Used to process generic objects (all types other than those listed above). Objects should be // serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized // byte array. -private[sql] object GENERIC extends ByteArrayColumnType[DataType](9, 16) +private[sql] object GENERIC extends ByteArrayColumnType[DataType](9, 16) { + override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) { + row(ordinal) = SparkSqlSerializer.deserialize[Any](value) + } + + override def getField(row: Row, ordinal: Int) = SparkSqlSerializer.serialize(row(ordinal)) +} private[sql] object ColumnType { - implicit def dataTypeToColumnType(dataType: DataType): ColumnType[_, _] = { + def apply(dataType: DataType): ColumnType[_, _] = { dataType match { case IntegerType => INT case LongType => LONG diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/inMemoryColumnarOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala similarity index 93% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/inMemoryColumnarOperators.scala rename to sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index f853759e5a306..8a24733047423 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/inMemoryColumnarOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -21,9 +21,6 @@ import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Attribute} import org.apache.spark.sql.execution.{SparkPlan, LeafNode} import org.apache.spark.sql.Row -/* Implicit conversions */ -import org.apache.spark.sql.columnar.ColumnType._ - private[sql] case class InMemoryColumnarTableScan(attributes: Seq[Attribute], child: SparkPlan) extends LeafNode { @@ -32,8 +29,8 @@ private[sql] case class InMemoryColumnarTableScan(attributes: Seq[Attribute], ch lazy val cachedColumnBuffers = { val output = child.output val cached = child.execute().mapPartitions { iterator => - val columnBuilders = output.map { a => - ColumnBuilder(a.dataType.typeId, 0, a.name) + val columnBuilders = output.map { attribute => + ColumnBuilder(ColumnType(attribute.dataType).typeId, 0, attribute.name) }.toArray var row: Row = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala index 2970c609b928d..7d49ab07f7a53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala @@ -29,7 +29,7 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor { private var nextNullIndex: Int = _ private var pos: Int = 0 - abstract override def initialize() { + abstract override protected def initialize() { nullsBuffer = underlyingBuffer.duplicate().order(ByteOrder.nativeOrder()) nullCount = nullsBuffer.getInt() nextNullIndex = if (nullCount > 0) nullsBuffer.getInt() else -1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala index 048d1f05c7df2..2a3b6fc1e46d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala @@ -22,10 +22,18 @@ import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.Row /** - * Builds a nullable column. The byte buffer of a nullable column contains: - * - 4 bytes for the null count (number of nulls) - * - positions for each null, in ascending order - * - the non-null data (column data type, compression type, data...) + * A stackable trait used for building byte buffer for a column containing null values. Memory + * layout of the final byte buffer is: + * {{{ + * .----------------------- Column type ID (4 bytes) + * | .------------------- Null count N (4 bytes) + * | | .--------------- Null positions (4 x N bytes, empty if null count is zero) + * | | | .--------- Non-null elements + * V V V V + * +---+---+-----+---------+ + * | | | ... | ... ... | + * +---+---+-----+---------+ + * }}} */ private[sql] trait NullableColumnBuilder extends ColumnBuilder { private var nulls: ByteBuffer = _ @@ -59,19 +67,8 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder { nulls.limit(nullDataLen) nulls.rewind() - // Column type ID is moved to the front, follows the null count, then non-null data - // - // +---------+ - // | 4 bytes | Column type ID - // +---------+ - // | 4 bytes | Null count - // +---------+ - // | ... | Null positions (if null count is not zero) - // +---------+ - // | ... | Non-null part (without column type ID) - // +---------+ val buffer = ByteBuffer - .allocate(4 + nullDataLen + nonNulls.limit) + .allocate(4 + 4 + nullDataLen + nonNulls.remaining()) .order(ByteOrder.nativeOrder()) .putInt(typeId) .putInt(nullCount) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala new file mode 100644 index 0000000000000..878cb84de106f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar.compression + +import java.nio.ByteBuffer + +import org.apache.spark.sql.catalyst.types.NativeType +import org.apache.spark.sql.columnar.{ColumnAccessor, NativeColumnAccessor} + +private[sql] trait CompressibleColumnAccessor[T <: NativeType] extends ColumnAccessor { + this: NativeColumnAccessor[T] => + + private var decoder: Decoder[T] = _ + + abstract override protected def initialize() = { + super.initialize() + decoder = CompressionScheme(underlyingBuffer.getInt()).decoder(buffer, columnType) + } + + abstract override def extractSingle(buffer: ByteBuffer): T#JvmType = decoder.next() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala new file mode 100644 index 0000000000000..3ac4b358ddf83 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar.compression + +import java.nio.{ByteBuffer, ByteOrder} + +import org.apache.spark.sql.{Logging, Row} +import org.apache.spark.sql.catalyst.types.NativeType +import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder} + +/** + * A stackable trait that builds optionally compressed byte buffer for a column. Memory layout of + * the final byte buffer is: + * {{{ + * .--------------------------- Column type ID (4 bytes) + * | .----------------------- Null count N (4 bytes) + * | | .------------------- Null positions (4 x N bytes, empty if null count is zero) + * | | | .------------- Compression scheme ID (4 bytes) + * | | | | .--------- Compressed non-null elements + * V V V V V + * +---+---+-----+---+---------+ + * | | | ... | | ... ... | + * +---+---+-----+---+---------+ + * \-----------/ \-----------/ + * header body + * }}} + */ +private[sql] trait CompressibleColumnBuilder[T <: NativeType] + extends ColumnBuilder with Logging { + + this: NativeColumnBuilder[T] with WithCompressionSchemes => + + import CompressionScheme._ + + val compressionEncoders = schemes.filter(_.supports(columnType)).map(_.encoder) + + protected def isWorthCompressing(encoder: Encoder) = { + encoder.compressionRatio < 0.8 + } + + private def gatherCompressibilityStats(row: Row, ordinal: Int) { + val field = columnType.getField(row, ordinal) + + var i = 0 + while (i < compressionEncoders.length) { + compressionEncoders(i).gatherCompressibilityStats(field, columnType) + i += 1 + } + } + + abstract override def appendFrom(row: Row, ordinal: Int) { + super.appendFrom(row, ordinal) + gatherCompressibilityStats(row, ordinal) + } + + abstract override def build() = { + val rawBuffer = super.build() + val encoder = { + val candidate = compressionEncoders.minBy(_.compressionRatio) + if (isWorthCompressing(candidate)) candidate else PassThrough.encoder + } + + val headerSize = columnHeaderSize(rawBuffer) + val compressedSize = if (encoder.compressedSize == 0) { + rawBuffer.limit - headerSize + } else { + encoder.compressedSize + } + + // Reserves 4 bytes for compression scheme ID + val compressedBuffer = ByteBuffer + .allocate(headerSize + 4 + compressedSize) + .order(ByteOrder.nativeOrder) + + copyColumnHeader(rawBuffer, compressedBuffer) + + logger.info(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}") + encoder.compress(rawBuffer, compressedBuffer, columnType) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala new file mode 100644 index 0000000000000..d3a4ac8df926b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar.compression + +import java.nio.ByteBuffer + +import org.apache.spark.sql.catalyst.types.NativeType +import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType} + +private[sql] trait Encoder { + def gatherCompressibilityStats[T <: NativeType]( + value: T#JvmType, + columnType: ColumnType[T, T#JvmType]) {} + + def compressedSize: Int + + def uncompressedSize: Int + + def compressionRatio: Double = { + if (uncompressedSize > 0) compressedSize.toDouble / uncompressedSize else 1.0 + } + + def compress[T <: NativeType]( + from: ByteBuffer, + to: ByteBuffer, + columnType: ColumnType[T, T#JvmType]): ByteBuffer +} + +private[sql] trait Decoder[T <: NativeType] extends Iterator[T#JvmType] + +private[sql] trait CompressionScheme { + def typeId: Int + + def supports(columnType: ColumnType[_, _]): Boolean + + def encoder: Encoder + + def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] +} + +private[sql] trait WithCompressionSchemes { + def schemes: Seq[CompressionScheme] +} + +private[sql] trait AllCompressionSchemes extends WithCompressionSchemes { + override val schemes: Seq[CompressionScheme] = { + Seq(PassThrough, RunLengthEncoding, DictionaryEncoding) + } +} + +private[sql] object CompressionScheme { + def apply(typeId: Int): CompressionScheme = typeId match { + case PassThrough.typeId => PassThrough + case _ => throw new UnsupportedOperationException() + } + + def copyColumnHeader(from: ByteBuffer, to: ByteBuffer) { + // Writes column type ID + to.putInt(from.getInt()) + + // Writes null count + val nullCount = from.getInt() + to.putInt(nullCount) + + // Writes null positions + var i = 0 + while (i < nullCount) { + to.putInt(from.getInt()) + i += 1 + } + } + + def columnHeaderSize(columnBuffer: ByteBuffer): Int = { + val header = columnBuffer.duplicate() + val nullCount = header.getInt(4) + // Column type ID + null count + null positions + 4 + 4 + 4 * nullCount + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala new file mode 100644 index 0000000000000..dc2c153faf8ad --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala @@ -0,0 +1,288 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar.compression + +import java.nio.ByteBuffer + +import scala.collection.mutable +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.runtimeMirror + +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.types.NativeType +import org.apache.spark.sql.columnar._ + +private[sql] case object PassThrough extends CompressionScheme { + override val typeId = 0 + + override def supports(columnType: ColumnType[_, _]) = true + + override def encoder = new this.Encoder + + override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + new this.Decoder(buffer, columnType) + } + + class Encoder extends compression.Encoder { + override def uncompressedSize = 0 + + override def compressedSize = 0 + + override def compress[T <: NativeType]( + from: ByteBuffer, + to: ByteBuffer, + columnType: ColumnType[T, T#JvmType]) = { + + // Writes compression type ID and copies raw contents + to.putInt(PassThrough.typeId).put(from).rewind() + to + } + } + + class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + extends compression.Decoder[T] { + + override def next() = columnType.extract(buffer) + + override def hasNext = buffer.hasRemaining + } +} + +private[sql] case object RunLengthEncoding extends CompressionScheme { + override def typeId = 1 + + override def encoder = new this.Encoder + + override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + new this.Decoder(buffer, columnType) + } + + override def supports(columnType: ColumnType[_, _]) = columnType match { + case INT | LONG | SHORT | BYTE | STRING | BOOLEAN => true + case _ => false + } + + class Encoder extends compression.Encoder { + private var _uncompressedSize = 0 + private var _compressedSize = 0 + + // Using `MutableRow` to store the last value to avoid boxing/unboxing cost. + private val lastValue = new GenericMutableRow(1) + private var lastRun = 0 + + override def uncompressedSize = _uncompressedSize + + override def compressedSize = _compressedSize + + override def gatherCompressibilityStats[T <: NativeType]( + value: T#JvmType, + columnType: ColumnType[T, T#JvmType]) { + + val actualSize = columnType.actualSize(value) + _uncompressedSize += actualSize + + if (lastValue.isNullAt(0)) { + columnType.setField(lastValue, 0, value) + lastRun = 1 + _compressedSize += actualSize + 4 + } else { + if (columnType.getField(lastValue, 0) == value) { + lastRun += 1 + } else { + _compressedSize += actualSize + 4 + columnType.setField(lastValue, 0, value) + lastRun = 1 + } + } + } + + override def compress[T <: NativeType]( + from: ByteBuffer, + to: ByteBuffer, + columnType: ColumnType[T, T#JvmType]) = { + + to.putInt(RunLengthEncoding.typeId) + + if (from.hasRemaining) { + var currentValue = columnType.extract(from) + var currentRun = 1 + + while (from.hasRemaining) { + val value = columnType.extract(from) + + if (value == currentValue) { + currentRun += 1 + } else { + // Writes current run + columnType.append(currentValue, to) + to.putInt(currentRun) + + // Resets current run + currentValue = value + currentRun = 1 + } + } + + // Writes the last run + columnType.append(currentValue, to) + to.putInt(currentRun) + } + + to.rewind() + to + } + } + + class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + extends compression.Decoder[T] { + + private var run = 0 + private var valueCount = 0 + private var currentValue: T#JvmType = _ + + override def next() = { + if (valueCount == run) { + currentValue = columnType.extract(buffer) + run = buffer.getInt() + valueCount = 1 + } else { + valueCount += 1 + } + + currentValue + } + + override def hasNext = buffer.hasRemaining + } +} + +private[sql] case object DictionaryEncoding extends CompressionScheme { + override def typeId: Int = 2 + + // 32K unique values allowed + private val MAX_DICT_SIZE = Short.MaxValue - 1 + + override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + new this.Decoder[T](buffer, columnType) + } + + override def encoder = new this.Encoder + + override def supports(columnType: ColumnType[_, _]) = columnType match { + case INT | LONG | STRING => true + case _ => false + } + + class Encoder extends compression.Encoder{ + // Size of the input, uncompressed, in bytes. Note that we only count until the dictionary + // overflows. + private var _uncompressedSize = 0 + + // If the number of distinct elements is too large, we discard the use of dictionary encoding + // and set the overflow flag to true. + private var overflow = false + + // Total number of elements. + private var count = 0 + + // The reverse mapping of _dictionary, i.e. mapping encoded integer to the value itself. + private var values = new mutable.ArrayBuffer[Any](1024) + + // The dictionary that maps a value to the encoded short integer. + private val dictionary = mutable.HashMap.empty[Any, Short] + + // Size of the serialized dictionary in bytes. Initialized to 4 since we need at least an `Int` + // to store dictionary element count. + private var dictionarySize = 4 + + override def gatherCompressibilityStats[T <: NativeType]( + value: T#JvmType, + columnType: ColumnType[T, T#JvmType]) { + + if (!overflow) { + val actualSize = columnType.actualSize(value) + count += 1 + _uncompressedSize += actualSize + + if (!dictionary.contains(value)) { + if (dictionary.size < MAX_DICT_SIZE) { + val clone = columnType.clone(value) + values += clone + dictionarySize += actualSize + dictionary(clone) = dictionary.size.toShort + } else { + overflow = true + values.clear() + dictionary.clear() + } + } + } + } + + override def compress[T <: NativeType]( + from: ByteBuffer, + to: ByteBuffer, + columnType: ColumnType[T, T#JvmType]) = { + + if (overflow) { + throw new IllegalStateException( + "Dictionary encoding should not be used because of dictionary overflow.") + } + + to.putInt(DictionaryEncoding.typeId) + .putInt(dictionary.size) + + var i = 0 + while (i < values.length) { + columnType.append(values(i).asInstanceOf[T#JvmType], to) + i += 1 + } + + while (from.hasRemaining) { + to.putShort(dictionary(columnType.extract(from))) + } + + to.rewind() + to + } + + override def uncompressedSize = _uncompressedSize + + override def compressedSize = if (overflow) Int.MaxValue else dictionarySize + count * 2 + } + + class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + extends compression.Decoder[T] { + + private val dictionary = { + // TODO Can we clean up this mess? Maybe move this to `DataType`? + implicit val classTag = { + val mirror = runtimeMirror(getClass.getClassLoader) + ClassTag[T#JvmType](mirror.runtimeClass(columnType.scalaTag.tpe)) + } + + Array.fill(buffer.getInt()) { + columnType.extract(buffer) + } + } + + override def next() = dictionary(buffer.getShort()) + + override def hasNext = buffer.hasRemaining + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 915f551fb2f01..d8e1b970c1d88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -32,7 +32,13 @@ class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) { kryo.setRegistrationRequired(false) kryo.register(classOf[MutablePair[_, _]]) kryo.register(classOf[Array[Any]]) + // This is kinda hacky... kryo.register(classOf[scala.collection.immutable.Map$Map1], new MapSerializer) + kryo.register(classOf[scala.collection.immutable.Map$Map2], new MapSerializer) + kryo.register(classOf[scala.collection.immutable.Map$Map3], new MapSerializer) + kryo.register(classOf[scala.collection.immutable.Map$Map4], new MapSerializer) + kryo.register(classOf[scala.collection.immutable.Map[_,_]], new MapSerializer) + kryo.register(classOf[scala.collection.Map[_,_]], new MapSerializer) kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow]) kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow]) kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 86f9d3e0fa954..b3e51fdf75270 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.parquet._ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SQLContext#SparkPlanner => - object SparkEquiInnerJoin extends Strategy { + object HashJoin extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case FilteredOperation(predicates, logical.Join(left, right, Inner, condition)) => logger.debug(s"Considering join: ${predicates ++ condition}") @@ -51,8 +51,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val leftKeys = joinKeys.map(_._1) val rightKeys = joinKeys.map(_._2) - val joinOp = execution.SparkEquiInnerJoin( - leftKeys, rightKeys, planLater(left), planLater(right)) + val joinOp = execution.HashJoin( + leftKeys, rightKeys, BuildRight, planLater(left), planLater(right)) // Make sure other conditions are met if present. if (otherPredicates.nonEmpty) { @@ -158,10 +158,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case other => other } - object TopK extends Strategy { + object TakeOrdered extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.StopAfter(IntegerLiteral(limit), logical.Sort(order, child)) => - execution.TopK(limit, order, planLater(child))(sparkContext) :: Nil + case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) => + execution.TakeOrdered(limit, order, planLater(child))(sparkContext) :: Nil case _ => Nil } } @@ -213,8 +213,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { sparkContext.parallelize(data.map(r => new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row)) execution.ExistingRdd(output, dataAsRdd) :: Nil - case logical.StopAfter(IntegerLiteral(limit), child) => - execution.StopAfter(limit, planLater(child))(sparkContext) :: Nil + case logical.Limit(IntegerLiteral(limit), child) => + execution.Limit(limit, planLater(child))(sparkContext) :: Nil case Unions(unionChildren) => execution.Union(unionChildren.map(planLater))(sparkContext) :: Nil case logical.Generate(generator, join, outer, _, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 65cb8f8becefa..524e5022ee14b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -19,27 +19,28 @@ package org.apache.spark.sql.execution import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext - +import org.apache.spark.{HashPartitioner, SparkConf, SparkContext} +import org.apache.spark.rdd.{RDD, ShuffledRDD} +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{OrderedDistribution, UnspecifiedDistribution} -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.util.MutablePair + case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { - def output = projectList.map(_.toAttribute) + override def output = projectList.map(_.toAttribute) - def execute() = child.execute().mapPartitions { iter => + override def execute() = child.execute().mapPartitions { iter => @transient val reusableProjection = new MutableProjection(projectList) iter.map(reusableProjection) } } case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { - def output = child.output + override def output = child.output - def execute() = child.execute().mapPartitions { iter => + override def execute() = child.execute().mapPartitions { iter => iter.filter(condition.apply(_).asInstanceOf[Boolean]) } } @@ -47,37 +48,59 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: SparkPlan) extends UnaryNode { - def output = child.output + override def output = child.output // TODO: How to pick seed? - def execute() = child.execute().sample(withReplacement, fraction, seed) + override def execute() = child.execute().sample(withReplacement, fraction, seed) } case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends SparkPlan { // TODO: attributes output by union should be distinct for nullability purposes - def output = children.head.output - def execute() = sc.union(children.map(_.execute())) + override def output = children.head.output + override def execute() = sc.union(children.map(_.execute())) override def otherCopyArgs = sc :: Nil } -case class StopAfter(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode { +/** + * Take the first limit elements. Note that the implementation is different depending on whether + * this is a terminal operator or not. If it is terminal and is invoked using executeCollect, + * this operator uses Spark's take method on the Spark driver. If it is not terminal or is + * invoked using execute, we first take the limit on each partition, and then repartition all the + * data to a single partition to compute the global limit. + */ +case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode { + // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan: + // partition local limit -> exchange into one partition -> partition local limit again + override def otherCopyArgs = sc :: Nil - def output = child.output + override def output = child.output override def executeCollect() = child.execute().map(_.copy()).take(limit) - // TODO: Terminal split should be implemented differently from non-terminal split. - // TODO: Pick num splits based on |limit|. - def execute() = sc.makeRDD(executeCollect(), 1) + override def execute() = { + val rdd = child.execute().mapPartitions { iter => + val mutablePair = new MutablePair[Boolean, Row]() + iter.take(limit).map(row => mutablePair.update(false, row)) + } + val part = new HashPartitioner(1) + val shuffled = new ShuffledRDD[Boolean, Row, MutablePair[Boolean, Row]](rdd, part) + shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + shuffled.mapPartitions(_.take(limit).map(_._2)) + } } -case class TopK(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) - (@transient sc: SparkContext) extends UnaryNode { +/** + * Take the first limit elements as defined by the sortOrder. This is logically equivalent to + * having a [[Limit]] operator after a [[Sort]] operator. This could have been named TopK, but + * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion. + */ +case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) + (@transient sc: SparkContext) extends UnaryNode { override def otherCopyArgs = sc :: Nil - def output = child.output + override def output = child.output @transient lazy val ordering = new RowOrdering(sortOrder) @@ -86,7 +109,7 @@ case class TopK(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. - def execute() = sc.makeRDD(executeCollect(), 1) + override def execute() = sc.makeRDD(executeCollect(), 1) } @@ -101,7 +124,7 @@ case class Sort( @transient lazy val ordering = new RowOrdering(sortOrder) - def execute() = attachTree(this, "sort") { + override def execute() = attachTree(this, "sort") { // TODO: Optimize sorting operation? child.execute() .mapPartitions( @@ -109,7 +132,7 @@ case class Sort( preservesPartitioning = true) } - def output = child.output + override def output = child.output } object ExistingRdd { @@ -130,6 +153,6 @@ object ExistingRdd { } case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { - def execute() = rdd + override def execute() = rdd } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index f0d21143ba5d1..c89dae9358bf7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -17,21 +17,22 @@ package org.apache.spark.sql.execution -import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer, BitSet} -import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext -import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning} -import org.apache.spark.rdd.PartitionLocalRDDFunctions._ +sealed abstract class BuildSide +case object BuildLeft extends BuildSide +case object BuildRight extends BuildSide -case class SparkEquiInnerJoin( +case class HashJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], + buildSide: BuildSide, left: SparkPlan, right: SparkPlan) extends BinaryNode { @@ -40,33 +41,93 @@ case class SparkEquiInnerJoin( override def requiredChildDistribution = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + val (buildPlan, streamedPlan) = buildSide match { + case BuildLeft => (left, right) + case BuildRight => (right, left) + } + + val (buildKeys, streamedKeys) = buildSide match { + case BuildLeft => (leftKeys, rightKeys) + case BuildRight => (rightKeys, leftKeys) + } + def output = left.output ++ right.output - def execute() = attachTree(this, "execute") { - val leftWithKeys = left.execute().mapPartitions { iter => - val generateLeftKeys = new Projection(leftKeys, left.output) - iter.map(row => (generateLeftKeys(row), row.copy())) - } + @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output) + @transient lazy val streamSideKeyGenerator = + () => new MutableProjection(streamedKeys, streamedPlan.output) - val rightWithKeys = right.execute().mapPartitions { iter => - val generateRightKeys = new Projection(rightKeys, right.output) - iter.map(row => (generateRightKeys(row), row.copy())) - } + def execute() = { - // Do the join. - val joined = filterNulls(leftWithKeys).joinLocally(filterNulls(rightWithKeys)) - // Drop join keys and merge input tuples. - joined.map { case (_, (leftTuple, rightTuple)) => buildRow(leftTuple ++ rightTuple) } - } + buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => + // TODO: Use Spark's HashMap implementation. + val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]() + var currentRow: Row = null + + // Create a mapping of buildKeys -> rows + while (buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = buildSideKeyGenerator(currentRow) + if(!rowKey.anyNull) { + val existingMatchList = hashTable.get(rowKey) + val matchList = if (existingMatchList == null) { + val newMatchList = new ArrayBuffer[Row]() + hashTable.put(rowKey, newMatchList) + newMatchList + } else { + existingMatchList + } + matchList += currentRow.copy() + } + } + + new Iterator[Row] { + private[this] var currentStreamedRow: Row = _ + private[this] var currentHashMatches: ArrayBuffer[Row] = _ + private[this] var currentMatchPosition: Int = -1 - /** - * Filters any rows where the any of the join keys is null, ensuring three-valued - * logic for the equi-join conditions. - */ - protected def filterNulls(rdd: RDD[(Row, Row)]) = - rdd.filter { - case (key: Seq[_], _) => !key.exists(_ == null) + // Mutable per row objects. + private[this] val joinRow = new JoinedRow + + private[this] val joinKeys = streamSideKeyGenerator() + + override final def hasNext: Boolean = + (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || + (streamIter.hasNext && fetchNext()) + + override final def next() = { + val ret = joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) + currentMatchPosition += 1 + ret + } + + /** + * Searches the streamed iterator for the next row that has at least one match in hashtable. + * + * @return true if the search is successful, and false the streamed iterator runs out of + * tuples. + */ + private final def fetchNext(): Boolean = { + currentHashMatches = null + currentMatchPosition = -1 + + while (currentHashMatches == null && streamIter.hasNext) { + currentStreamedRow = streamIter.next() + if (!joinKeys(currentStreamedRow).anyNull) { + currentHashMatches = hashTable.get(joinKeys.currentValue) + } + } + + if (currentHashMatches == null) { + false + } else { + currentMatchPosition = 0 + true + } + } + } } + } } case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { @@ -95,17 +156,19 @@ case class BroadcastNestedLoopJoin( def right = broadcast @transient lazy val boundCondition = - condition - .map(c => BindReferences.bindReference(c, left.output ++ right.output)) - .getOrElse(Literal(true)) + InterpretedPredicate( + condition + .map(c => BindReferences.bindReference(c, left.output ++ right.output)) + .getOrElse(Literal(true))) def execute() = { val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter => - val matchedRows = new mutable.ArrayBuffer[Row] - val includedBroadcastTuples = new mutable.BitSet(broadcastedRelation.value.size) + val matchedRows = new ArrayBuffer[Row] + // TODO: Use Spark's BitSet. + val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size) val joinedRow = new JoinedRow streamedIter.foreach { streamedRow => @@ -115,7 +178,7 @@ case class BroadcastNestedLoopJoin( while (i < broadcastedRelation.value.size) { // TODO: One bitset per partition instead of per row. val broadcastedRow = broadcastedRelation.value(i) - if (boundCondition(joinedRow(streamedRow, broadcastedRow)).asInstanceOf[Boolean]) { + if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { matchedRows += buildRow(streamedRow ++ broadcastedRow) matched = true includedBroadcastTuples += i diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 2b825f84ee910..4ab755c096bd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -36,7 +36,7 @@ import parquet.schema.{MessageType, MessageTypeParser} import parquet.schema.{PrimitiveType => ParquetPrimitiveType} import parquet.schema.{Type => ParquetType} -import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row} import org.apache.spark.sql.catalyst.plans.logical.{BaseRelation, LogicalPlan} import org.apache.spark.sql.catalyst.types._ @@ -54,26 +54,37 @@ import org.apache.spark.sql.catalyst.types._ * @param tableName The name of the relation that can be used in queries. * @param path The path to the Parquet file. */ -case class ParquetRelation(tableName: String, path: String) extends BaseRelation { +case class ParquetRelation(tableName: String, path: String) + extends BaseRelation with MultiInstanceRelation { - /** Schema derived from ParquetFile **/ + /** Schema derived from ParquetFile */ def parquetSchema: MessageType = ParquetTypesConverter .readMetaData(new Path(path)) .getFileMetaData .getSchema - /** Attributes **/ + /** Attributes */ val attributes = ParquetTypesConverter .convertToAttributes(parquetSchema) - /** Output **/ + /** Output */ override val output = attributes // Parquet files have no concepts of keys, therefore no Partitioner // Note: we could allow Block level access; needs to be thought through override def isPartitioned = false + + override def newInstance = ParquetRelation(tableName, path).asInstanceOf[this.type] + + // Equals must also take into account the output attributes so that we can distinguish between + // different instances of the same relation, + override def equals(other: Any) = other match { + case p: ParquetRelation => + p.tableName == tableName && p.path == path && p.output == output + case _ => false + } } object ParquetRelation { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala new file mode 100644 index 0000000000000..e5902c3cae381 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalatest.FunSuite +import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.execution.SparkLogicalPlan +import org.apache.spark.sql.columnar.InMemoryColumnarTableScan + +class CachedTableSuite extends QueryTest { + TestData // Load test tables. + + test("read from cached table and uncache") { + TestSQLContext.cacheTable("testData") + + checkAnswer( + TestSQLContext.table("testData"), + testData.collect().toSeq + ) + + TestSQLContext.table("testData").queryExecution.analyzed match { + case SparkLogicalPlan(_ : InMemoryColumnarTableScan) => // Found evidence of caching + case noCache => fail(s"No cache node found in plan $noCache") + } + + TestSQLContext.uncacheTable("testData") + + checkAnswer( + TestSQLContext.table("testData"), + testData.collect().toSeq + ) + + TestSQLContext.table("testData").queryExecution.analyzed match { + case cachePlan @ SparkLogicalPlan(_ : InMemoryColumnarTableScan) => + fail(s"Table still cached after uncache: $cachePlan") + case noCache => // Table uncached successfully + } + } + + test("correct error on uncache of non-cached table") { + intercept[IllegalArgumentException] { + TestSQLContext.uncacheTable("testData") + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index fa4a1d5189ea6..4c4fd6dbbedb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -216,4 +216,31 @@ class SQLQuerySuite extends QueryTest { (null, null, 5, "E") :: (null, null, 6, "F") :: Nil) } + + test("select with table name as qualifier") { + checkAnswer( + sql("SELECT testData.value FROM testData WHERE testData.key = 1"), + Seq(Seq("1"))) + } + + test("inner join ON with table name as qualifier") { + checkAnswer( + sql("SELECT * FROM upperCaseData JOIN lowerCaseData ON lowerCaseData.n = upperCaseData.N"), + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, "C", 3, "c"), + (4, "D", 4, "d"))) + } + + test("qualified select with inner join ON with table name as qualifier") { + checkAnswer( + sql("SELECT upperCaseData.N, upperCaseData.L FROM upperCaseData JOIN lowerCaseData " + + "ON lowerCaseData.n = upperCaseData.N"), + Seq( + (1, "A"), + (2, "B"), + (3, "C"), + (4, "D"))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala new file mode 100644 index 0000000000000..70033a050c78c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalatest.FunSuite + +import org.apache.spark.sql.test.TestSQLContext._ + +case class ReflectData( + stringField: String, + intField: Int, + longField: Long, + floatField: Float, + doubleField: Double, + shortField: Short, + byteField: Byte, + booleanField: Boolean, + decimalField: BigDecimal, + seqInt: Seq[Int]) + +case class ReflectBinary(data: Array[Byte]) + +class ScalaReflectionRelationSuite extends FunSuite { + test("query case class RDD") { + val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, + BigDecimal(1), Seq(1,2,3)) + val rdd = sparkContext.parallelize(data :: Nil) + rdd.registerAsTable("reflectData") + + assert(sql("SELECT * FROM reflectData").collect().head === data.productIterator.toSeq) + } + + // Equality is broken for Arrays, so we test that separately. + test("query binary data") { + val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil) + rdd.registerAsTable("reflectBinary") + + val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]] + assert(result.toSeq === Seq[Byte](1)) + } +} \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala new file mode 100644 index 0000000000000..78640b876d4aa --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.types._ + +class ColumnStatsSuite extends FunSuite { + testColumnStats(classOf[BooleanColumnStats], BOOLEAN) + testColumnStats(classOf[ByteColumnStats], BYTE) + testColumnStats(classOf[ShortColumnStats], SHORT) + testColumnStats(classOf[IntColumnStats], INT) + testColumnStats(classOf[LongColumnStats], LONG) + testColumnStats(classOf[FloatColumnStats], FLOAT) + testColumnStats(classOf[DoubleColumnStats], DOUBLE) + testColumnStats(classOf[StringColumnStats], STRING) + + def testColumnStats[T <: NativeType, U <: NativeColumnStats[T]]( + columnStatsClass: Class[U], + columnType: NativeColumnType[T]) { + + val columnStatsName = columnStatsClass.getSimpleName + + test(s"$columnStatsName: empty") { + val columnStats = columnStatsClass.newInstance() + expectResult(columnStats.initialBounds, "Wrong initial bounds") { + (columnStats.lowerBound, columnStats.upperBound) + } + } + + test(s"$columnStatsName: non-empty") { + import ColumnarTestUtils._ + + val columnStats = columnStatsClass.newInstance() + val rows = Seq.fill(10)(makeRandomRow(columnType)) + rows.foreach(columnStats.gatherStats(_, 0)) + + val values = rows.map(_.head.asInstanceOf[T#JvmType]) + val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]] + + expectResult(values.min(ordering), "Wrong lower bound")(columnStats.lowerBound) + expectResult(values.max(ordering), "Wrong upper bound")(columnStats.upperBound) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 2d431affbcfcc..1d3608ed2d9ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -19,46 +19,56 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import scala.util.Random - import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer class ColumnTypeSuite extends FunSuite { - val columnTypes = Seq(INT, SHORT, LONG, BYTE, DOUBLE, FLOAT, STRING, BINARY, GENERIC) + val DEFAULT_BUFFER_SIZE = 512 test("defaultSize") { - val defaultSize = Seq(4, 2, 8, 1, 8, 4, 8, 16, 16) + val checks = Map( + INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, + BOOLEAN -> 1, STRING -> 8, BINARY -> 16, GENERIC -> 16) - columnTypes.zip(defaultSize).foreach { case (columnType, size) => - assert(columnType.defaultSize === size) + checks.foreach { case (columnType, expectedSize) => + expectResult(expectedSize, s"Wrong defaultSize for $columnType") { + columnType.defaultSize + } } } test("actualSize") { - val expectedSizes = Seq(4, 2, 8, 1, 8, 4, 4 + 5, 4 + 4, 4 + 11) - val actualSizes = Seq( - INT.actualSize(Int.MaxValue), - SHORT.actualSize(Short.MaxValue), - LONG.actualSize(Long.MaxValue), - BYTE.actualSize(Byte.MaxValue), - DOUBLE.actualSize(Double.MaxValue), - FLOAT.actualSize(Float.MaxValue), - STRING.actualSize("hello"), - BINARY.actualSize(new Array[Byte](4)), - GENERIC.actualSize(SparkSqlSerializer.serialize(Map(1 -> "a")))) - - expectedSizes.zip(actualSizes).foreach { case (expected, actual) => - assert(expected === actual) + def checkActualSize[T <: DataType, JvmType]( + columnType: ColumnType[T, JvmType], + value: JvmType, + expected: Int) { + + expectResult(expected, s"Wrong actualSize for $columnType") { + columnType.actualSize(value) + } } + + checkActualSize(INT, Int.MaxValue, 4) + checkActualSize(SHORT, Short.MaxValue, 2) + checkActualSize(LONG, Long.MaxValue, 8) + checkActualSize(BYTE, Byte.MaxValue, 1) + checkActualSize(DOUBLE, Double.MaxValue, 8) + checkActualSize(FLOAT, Float.MaxValue, 4) + checkActualSize(BOOLEAN, true, 1) + checkActualSize(STRING, "hello", 4 + 5) + + val binary = Array.fill[Byte](4)(0: Byte) + checkActualSize(BINARY, binary, 4 + 4) + + val generic = Map(1 -> "a") + checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 11) } - testNumericColumnType[BooleanType.type, Boolean]( + testNativeColumnType[BooleanType.type]( BOOLEAN, - Array.fill(4)(Random.nextBoolean()), - ByteBuffer.allocate(32), (buffer: ByteBuffer, v: Boolean) => { buffer.put((if (v) 1 else 0).toByte) }, @@ -66,105 +76,42 @@ class ColumnTypeSuite extends FunSuite { buffer.get() == 1 }) - testNumericColumnType[IntegerType.type, Int]( - INT, - Array.fill(4)(Random.nextInt()), - ByteBuffer.allocate(32), - (_: ByteBuffer).putInt(_), - (_: ByteBuffer).getInt) - - testNumericColumnType[ShortType.type, Short]( - SHORT, - Array.fill(4)(Random.nextInt(Short.MaxValue).asInstanceOf[Short]), - ByteBuffer.allocate(32), - (_: ByteBuffer).putShort(_), - (_: ByteBuffer).getShort) - - testNumericColumnType[LongType.type, Long]( - LONG, - Array.fill(4)(Random.nextLong()), - ByteBuffer.allocate(64), - (_: ByteBuffer).putLong(_), - (_: ByteBuffer).getLong) - - testNumericColumnType[ByteType.type, Byte]( - BYTE, - Array.fill(4)(Random.nextInt(Byte.MaxValue).asInstanceOf[Byte]), - ByteBuffer.allocate(64), - (_: ByteBuffer).put(_), - (_: ByteBuffer).get) - - testNumericColumnType[DoubleType.type, Double]( - DOUBLE, - Array.fill(4)(Random.nextDouble()), - ByteBuffer.allocate(64), - (_: ByteBuffer).putDouble(_), - (_: ByteBuffer).getDouble) - - testNumericColumnType[FloatType.type, Float]( - FLOAT, - Array.fill(4)(Random.nextFloat()), - ByteBuffer.allocate(64), - (_: ByteBuffer).putFloat(_), - (_: ByteBuffer).getFloat) - - test("STRING") { - val buffer = ByteBuffer.allocate(128) - val seq = Array("hello", "world", "spark", "sql") - - seq.map(_.getBytes).foreach { bytes: Array[Byte] => - buffer.putInt(bytes.length).put(bytes) - } + testNativeColumnType[IntegerType.type](INT, _.putInt(_), _.getInt) - buffer.rewind() - seq.foreach { s => - assert(s === STRING.extract(buffer)) - } + testNativeColumnType[ShortType.type](SHORT, _.putShort(_), _.getShort) - buffer.rewind() - seq.foreach(STRING.append(_, buffer)) + testNativeColumnType[LongType.type](LONG, _.putLong(_), _.getLong) - buffer.rewind() - seq.foreach { s => - val length = buffer.getInt - assert(length === s.getBytes.length) + testNativeColumnType[ByteType.type](BYTE, _.put(_), _.get) + + testNativeColumnType[DoubleType.type](DOUBLE, _.putDouble(_), _.getDouble) + + testNativeColumnType[FloatType.type](FLOAT, _.putFloat(_), _.getFloat) + testNativeColumnType[StringType.type]( + STRING, + (buffer: ByteBuffer, string: String) => { + val bytes = string.getBytes() + buffer.putInt(bytes.length).put(string.getBytes) + }, + (buffer: ByteBuffer) => { + val length = buffer.getInt() val bytes = new Array[Byte](length) buffer.get(bytes, 0, length) - assert(s === new String(bytes)) - } - } - - test("BINARY") { - val buffer = ByteBuffer.allocate(128) - val seq = Array.fill(4) { - val bytes = new Array[Byte](4) - Random.nextBytes(bytes) - bytes - } + new String(bytes) + }) - seq.foreach { bytes => + testColumnType[BinaryType.type, Array[Byte]]( + BINARY, + (buffer: ByteBuffer, bytes: Array[Byte]) => { buffer.putInt(bytes.length).put(bytes) - } - - buffer.rewind() - seq.foreach { b => - assert(b === BINARY.extract(buffer)) - } - - buffer.rewind() - seq.foreach(BINARY.append(_, buffer)) - - buffer.rewind() - seq.foreach { b => - val length = buffer.getInt - assert(length === b.length) - + }, + (buffer: ByteBuffer) => { + val length = buffer.getInt() val bytes = new Array[Byte](length) buffer.get(bytes, 0, length) - assert(b === bytes) - } - } + bytes + }) test("GENERIC") { val buffer = ByteBuffer.allocate(512) @@ -177,43 +124,58 @@ class ColumnTypeSuite extends FunSuite { val length = buffer.getInt() assert(length === serializedObj.length) - val bytes = new Array[Byte](length) - buffer.get(bytes, 0, length) - assert(obj === SparkSqlSerializer.deserialize(bytes)) + expectResult(obj, "Deserialized object didn't equal to the original object") { + val bytes = new Array[Byte](length) + buffer.get(bytes, 0, length) + SparkSqlSerializer.deserialize(bytes) + } buffer.rewind() buffer.putInt(serializedObj.length).put(serializedObj) - buffer.rewind() - assert(obj === SparkSqlSerializer.deserialize(GENERIC.extract(buffer))) + expectResult(obj, "Deserialized object didn't equal to the original object") { + buffer.rewind() + SparkSqlSerializer.deserialize(GENERIC.extract(buffer)) + } + } + + def testNativeColumnType[T <: NativeType]( + columnType: NativeColumnType[T], + putter: (ByteBuffer, T#JvmType) => Unit, + getter: (ByteBuffer) => T#JvmType) { + + testColumnType[T, T#JvmType](columnType, putter, getter) } - def testNumericColumnType[T <: DataType, JvmType]( + def testColumnType[T <: DataType, JvmType]( columnType: ColumnType[T, JvmType], - seq: Seq[JvmType], - buffer: ByteBuffer, putter: (ByteBuffer, JvmType) => Unit, getter: (ByteBuffer) => JvmType) { - val columnTypeName = columnType.getClass.getSimpleName.stripSuffix("$") + val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE) + val seq = (0 until 4).map(_ => makeRandomValue(columnType)) - test(s"$columnTypeName.extract") { + test(s"$columnType.extract") { buffer.rewind() seq.foreach(putter(buffer, _)) buffer.rewind() - seq.foreach { i => - assert(i === columnType.extract(buffer)) + seq.foreach { expected => + assert( + expected === columnType.extract(buffer), + "Extracted value didn't equal to the original one") } } - test(s"$columnTypeName.append") { + test(s"$columnType.append") { buffer.rewind() seq.foreach(columnType.append(_, buffer)) buffer.rewind() - seq.foreach { i => - assert(i === getter(buffer)) + seq.foreach { expected => + assert( + expected === getter(buffer), + "Extracted value didn't equal to the original one") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarQuerySuite.scala index 928851a385d41..70b2e851737f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarQuerySuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.columnar +import org.apache.spark.sql.{QueryTest, TestData} import org.apache.spark.sql.execution.SparkLogicalPlan import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{TestData, DslQuerySuite} -class ColumnarQuerySuite extends DslQuerySuite { +class ColumnarQuerySuite extends QueryTest { import TestData._ import TestSQLContext._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestData.scala deleted file mode 100644 index ddcdede8d1a4a..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestData.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.columnar - -import scala.util.Random - -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow - -// TODO Enrich test data -object ColumnarTestData { - object GenericMutableRow { - def apply(values: Any*) = { - val row = new GenericMutableRow(values.length) - row.indices.foreach { i => - row(i) = values(i) - } - row - } - } - - def randomBytes(length: Int) = { - val bytes = new Array[Byte](length) - Random.nextBytes(bytes) - bytes - } - - val nonNullRandomRow = GenericMutableRow( - Random.nextInt(), - Random.nextLong(), - Random.nextFloat(), - Random.nextDouble(), - Random.nextBoolean(), - Random.nextInt(Byte.MaxValue).asInstanceOf[Byte], - Random.nextInt(Short.MaxValue).asInstanceOf[Short], - Random.nextString(Random.nextInt(64)), - randomBytes(Random.nextInt(64)), - Map(Random.nextInt() -> Random.nextString(4))) - - val nullRow = GenericMutableRow(Seq.fill(10)(null): _*) -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala new file mode 100644 index 0000000000000..04bdc43d95328 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar + +import scala.collection.immutable.HashSet +import scala.util.Random + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.types.{DataType, NativeType} + +object ColumnarTestUtils { + def makeNullRow(length: Int) = { + val row = new GenericMutableRow(length) + (0 until length).foreach(row.setNullAt) + row + } + + def makeRandomValue[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]): JvmType = { + def randomBytes(length: Int) = { + val bytes = new Array[Byte](length) + Random.nextBytes(bytes) + bytes + } + + (columnType match { + case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte + case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort + case INT => Random.nextInt() + case LONG => Random.nextLong() + case FLOAT => Random.nextFloat() + case DOUBLE => Random.nextDouble() + case STRING => Random.nextString(Random.nextInt(32)) + case BOOLEAN => Random.nextBoolean() + case BINARY => randomBytes(Random.nextInt(32)) + case _ => + // Using a random one-element map instead of an arbitrary object + Map(Random.nextInt() -> Random.nextString(Random.nextInt(32))) + }).asInstanceOf[JvmType] + } + + def makeRandomValues( + head: ColumnType[_ <: DataType, _], + tail: ColumnType[_ <: DataType, _]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail) + + def makeRandomValues(columnTypes: Seq[ColumnType[_ <: DataType, _]]): Seq[Any] = { + columnTypes.map(makeRandomValue(_)) + } + + def makeUniqueRandomValues[T <: DataType, JvmType]( + columnType: ColumnType[T, JvmType], + count: Int): Seq[JvmType] = { + + Iterator.iterate(HashSet.empty[JvmType]) { set => + set + Iterator.continually(makeRandomValue(columnType)).filterNot(set.contains).next() + }.drop(count).next().toSeq + } + + def makeRandomRow( + head: ColumnType[_ <: DataType, _], + tail: ColumnType[_ <: DataType, _]*): Row = makeRandomRow(Seq(head) ++ tail) + + def makeRandomRow(columnTypes: Seq[ColumnType[_ <: DataType, _]]): Row = { + val row = new GenericMutableRow(columnTypes.length) + makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) => + row(index) = value + } + row + } + + def makeUniqueValuesAndSingleValueRows[T <: NativeType]( + columnType: NativeColumnType[T], + count: Int) = { + + val values = makeUniqueRandomValues(columnType, count) + val rows = values.map { value => + val row = new GenericMutableRow(1) + row(0) = value + row + } + + (values, rows) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index d413d483f4e7e..4a21eb6201a69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -17,12 +17,29 @@ package org.apache.spark.sql.columnar +import java.nio.ByteBuffer + import org.scalatest.FunSuite -import org.apache.spark.sql.catalyst.types.DataType + import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.types.DataType + +class TestNullableColumnAccessor[T <: DataType, JvmType]( + buffer: ByteBuffer, + columnType: ColumnType[T, JvmType]) + extends BasicColumnAccessor(buffer, columnType) + with NullableColumnAccessor + +object TestNullableColumnAccessor { + def apply[T <: DataType, JvmType](buffer: ByteBuffer, columnType: ColumnType[T, JvmType]) = { + // Skips the column type ID + buffer.getInt() + new TestNullableColumnAccessor(buffer, columnType) + } +} class NullableColumnAccessorSuite extends FunSuite { - import ColumnarTestData._ + import ColumnarTestUtils._ Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC).foreach { testNullableColumnAccessor(_) @@ -30,30 +47,32 @@ class NullableColumnAccessorSuite extends FunSuite { def testNullableColumnAccessor[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") + val nullRow = makeNullRow(1) - test(s"$typeName accessor: empty column") { - val builder = ColumnBuilder(columnType.typeId, 4) - val accessor = ColumnAccessor(builder.build()) + test(s"Nullable $typeName column accessor: empty column") { + val builder = TestNullableColumnBuilder(columnType) + val accessor = TestNullableColumnAccessor(builder.build(), columnType) assert(!accessor.hasNext) } - test(s"$typeName accessor: access null values") { - val builder = ColumnBuilder(columnType.typeId, 4) + test(s"Nullable $typeName column accessor: access null values") { + val builder = TestNullableColumnBuilder(columnType) + val randomRow = makeRandomRow(columnType) (0 until 4).foreach { _ => - builder.appendFrom(nonNullRandomRow, columnType.typeId) - builder.appendFrom(nullRow, columnType.typeId) + builder.appendFrom(randomRow, 0) + builder.appendFrom(nullRow, 0) } - val accessor = ColumnAccessor(builder.build()) + val accessor = TestNullableColumnAccessor(builder.build(), columnType) val row = new GenericMutableRow(1) (0 until 4).foreach { _ => accessor.extractTo(row, 0) - assert(row(0) === nonNullRandomRow(columnType.typeId)) + assert(row(0) === randomRow(0)) accessor.extractTo(row, 0) - assert(row(0) === null) + assert(row.isNullAt(0)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index 5222a47e1ab87..d9d1e1bfddb75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -19,63 +19,71 @@ package org.apache.spark.sql.columnar import org.scalatest.FunSuite -import org.apache.spark.sql.catalyst.types.DataType +import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.execution.SparkSqlSerializer +class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) + extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType) + with NullableColumnBuilder + +object TestNullableColumnBuilder { + def apply[T <: DataType, JvmType](columnType: ColumnType[T, JvmType], initialSize: Int = 0) = { + val builder = new TestNullableColumnBuilder(columnType) + builder.initialize(initialSize) + builder + } +} + class NullableColumnBuilderSuite extends FunSuite { - import ColumnarTestData._ + import ColumnarTestUtils._ Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC).foreach { testNullableColumnBuilder(_) } def testNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) { - val columnBuilder = ColumnBuilder(columnType.typeId) val typeName = columnType.getClass.getSimpleName.stripSuffix("$") test(s"$typeName column builder: empty column") { - columnBuilder.initialize(4) - + val columnBuilder = TestNullableColumnBuilder(columnType) val buffer = columnBuilder.build() - // For column type ID - assert(buffer.getInt() === columnType.typeId) - // For null count - assert(buffer.getInt === 0) + expectResult(columnType.typeId, "Wrong column type ID")(buffer.getInt()) + expectResult(0, "Wrong null count")(buffer.getInt()) assert(!buffer.hasRemaining) } test(s"$typeName column builder: buffer size auto growth") { - columnBuilder.initialize(4) + val columnBuilder = TestNullableColumnBuilder(columnType) + val randomRow = makeRandomRow(columnType) - (0 until 4) foreach { _ => - columnBuilder.appendFrom(nonNullRandomRow, columnType.typeId) + (0 until 4).foreach { _ => + columnBuilder.appendFrom(randomRow, 0) } val buffer = columnBuilder.build() - // For column type ID - assert(buffer.getInt() === columnType.typeId) - // For null count - assert(buffer.getInt() === 0) + expectResult(columnType.typeId, "Wrong column type ID")(buffer.getInt()) + expectResult(0, "Wrong null count")(buffer.getInt()) } test(s"$typeName column builder: null values") { - columnBuilder.initialize(4) + val columnBuilder = TestNullableColumnBuilder(columnType) + val randomRow = makeRandomRow(columnType) + val nullRow = makeNullRow(1) - (0 until 4) foreach { _ => - columnBuilder.appendFrom(nonNullRandomRow, columnType.typeId) - columnBuilder.appendFrom(nullRow, columnType.typeId) + (0 until 4).foreach { _ => + columnBuilder.appendFrom(randomRow, 0) + columnBuilder.appendFrom(nullRow, 0) } val buffer = columnBuilder.build() - // For column type ID - assert(buffer.getInt() === columnType.typeId) - // For null count - assert(buffer.getInt() === 4) + expectResult(columnType.typeId, "Wrong column type ID")(buffer.getInt()) + expectResult(4, "Wrong null count")(buffer.getInt()) + // For null positions - (1 to 7 by 2).foreach(i => assert(buffer.getInt() === i)) + (1 to 7 by 2).foreach(expectResult(_, "Wrong null position")(buffer.getInt())) // For non-null values (0 until 4).foreach { _ => @@ -84,7 +92,8 @@ class NullableColumnBuilderSuite extends FunSuite { } else { columnType.extract(buffer) } - assert(actual === nonNullRandomRow(columnType.typeId)) + + assert(actual === randomRow(0), "Extracted value didn't equal to the original one") } assert(!buffer.hasRemaining) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala new file mode 100644 index 0000000000000..184691ab5b46a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar.compression + +import java.nio.ByteBuffer + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.types.NativeType +import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow + +class DictionaryEncodingSuite extends FunSuite { + testDictionaryEncoding(new IntColumnStats, INT) + testDictionaryEncoding(new LongColumnStats, LONG) + testDictionaryEncoding(new StringColumnStats, STRING) + + def testDictionaryEncoding[T <: NativeType]( + columnStats: NativeColumnStats[T], + columnType: NativeColumnType[T]) { + + val typeName = columnType.getClass.getSimpleName.stripSuffix("$") + + def buildDictionary(buffer: ByteBuffer) = { + (0 until buffer.getInt()).map(columnType.extract(buffer) -> _.toShort).toMap + } + + test(s"$DictionaryEncoding with $typeName: simple case") { + // ------------- + // Tests encoder + // ------------- + + val builder = TestCompressibleColumnBuilder(columnStats, columnType, DictionaryEncoding) + val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, 2) + + builder.initialize(0) + builder.appendFrom(rows(0), 0) + builder.appendFrom(rows(1), 0) + builder.appendFrom(rows(0), 0) + builder.appendFrom(rows(1), 0) + + val buffer = builder.build() + val headerSize = CompressionScheme.columnHeaderSize(buffer) + // 4 extra bytes for dictionary size + val dictionarySize = 4 + values.map(columnType.actualSize).sum + // 4 `Short`s, 2 bytes each + val compressedSize = dictionarySize + 2 * 4 + // 4 extra bytes for compression scheme type ID + expectResult(headerSize + 4 + compressedSize, "Wrong buffer capacity")(buffer.capacity) + + // Skips column header + buffer.position(headerSize) + expectResult(DictionaryEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + val dictionary = buildDictionary(buffer) + Array[Short](0, 1).foreach { i => + expectResult(i, "Wrong dictionary entry")(dictionary(values(i))) + } + + Array[Short](0, 1, 0, 1).foreach { + expectResult(_, "Wrong column element value")(buffer.getShort()) + } + + // ------------- + // Tests decoder + // ------------- + + // Rewinds, skips column header and 4 more bytes for compression scheme ID + buffer.rewind().position(headerSize + 4) + + val decoder = new DictionaryEncoding.Decoder[T](buffer, columnType) + + Array[Short](0, 1, 0, 1).foreach { i => + expectResult(values(i), "Wrong decoded value")(decoder.next()) + } + + assert(!decoder.hasNext) + } + } + + test(s"$DictionaryEncoding: overflow") { + val builder = TestCompressibleColumnBuilder(new IntColumnStats, INT, DictionaryEncoding) + builder.initialize(0) + + (0 to Short.MaxValue).foreach { n => + val row = new GenericMutableRow(1) + row.setInt(0, n) + builder.appendFrom(row, 0) + } + + withClue("Dictionary overflowed, encoding should fail") { + intercept[Throwable] { + builder.build() + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala new file mode 100644 index 0000000000000..2089ad120d4f2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar.compression + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.types.NativeType +import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.columnar.ColumnarTestUtils._ + +class RunLengthEncodingSuite extends FunSuite { + testRunLengthEncoding(new BooleanColumnStats, BOOLEAN) + testRunLengthEncoding(new ByteColumnStats, BYTE) + testRunLengthEncoding(new ShortColumnStats, SHORT) + testRunLengthEncoding(new IntColumnStats, INT) + testRunLengthEncoding(new LongColumnStats, LONG) + testRunLengthEncoding(new StringColumnStats, STRING) + + def testRunLengthEncoding[T <: NativeType]( + columnStats: NativeColumnStats[T], + columnType: NativeColumnType[T]) { + + val typeName = columnType.getClass.getSimpleName.stripSuffix("$") + + test(s"$RunLengthEncoding with $typeName: simple case") { + // ------------- + // Tests encoder + // ------------- + + val builder = TestCompressibleColumnBuilder(columnStats, columnType, RunLengthEncoding) + val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, 2) + + builder.initialize(0) + builder.appendFrom(rows(0), 0) + builder.appendFrom(rows(0), 0) + builder.appendFrom(rows(1), 0) + builder.appendFrom(rows(1), 0) + + val buffer = builder.build() + val headerSize = CompressionScheme.columnHeaderSize(buffer) + // 4 extra bytes each run for run length + val compressedSize = values.map(columnType.actualSize(_) + 4).sum + // 4 extra bytes for compression scheme type ID + expectResult(headerSize + 4 + compressedSize, "Wrong buffer capacity")(buffer.capacity) + + // Skips column header + buffer.position(headerSize) + expectResult(RunLengthEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + Array(0, 1).foreach { i => + expectResult(values(i), "Wrong column element value")(columnType.extract(buffer)) + expectResult(2, "Wrong run length")(buffer.getInt()) + } + + // ------------- + // Tests decoder + // ------------- + + // Rewinds, skips column header and 4 more bytes for compression scheme ID + buffer.rewind().position(headerSize + 4) + + val decoder = new RunLengthEncoding.Decoder[T](buffer, columnType) + + Array(0, 0, 1, 1).foreach { i => + expectResult(values(i), "Wrong decoded value")(decoder.next()) + } + + assert(!decoder.hasNext) + } + + test(s"$RunLengthEncoding with $typeName: run length == 1") { + // ------------- + // Tests encoder + // ------------- + + val builder = TestCompressibleColumnBuilder(columnStats, columnType, RunLengthEncoding) + val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, 2) + + builder.initialize(0) + builder.appendFrom(rows(0), 0) + builder.appendFrom(rows(1), 0) + + val buffer = builder.build() + val headerSize = CompressionScheme.columnHeaderSize(buffer) + // 4 bytes each run for run length + val compressedSize = values.map(columnType.actualSize(_) + 4).sum + // 4 bytes for compression scheme type ID + expectResult(headerSize + 4 + compressedSize, "Wrong buffer capacity")(buffer.capacity) + + // Skips column header + buffer.position(headerSize) + expectResult(RunLengthEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + Array(0, 1).foreach { i => + expectResult(values(i), "Wrong column element value")(columnType.extract(buffer)) + expectResult(1, "Wrong run length")(buffer.getInt()) + } + + // ------------- + // Tests decoder + // ------------- + + // Rewinds, skips column header and 4 more bytes for compression scheme ID + buffer.rewind().position(headerSize + 4) + + val decoder = new RunLengthEncoding.Decoder[T](buffer, columnType) + + Array(0, 1).foreach { i => + expectResult(values(i), "Wrong decoded value")(decoder.next()) + } + + assert(!decoder.hasNext) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala new file mode 100644 index 0000000000000..e0ec812863dcf --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar.compression + +import org.apache.spark.sql.catalyst.types.NativeType +import org.apache.spark.sql.columnar._ + +class TestCompressibleColumnBuilder[T <: NativeType]( + override val columnStats: NativeColumnStats[T], + override val columnType: NativeColumnType[T], + override val schemes: Seq[CompressionScheme]) + extends NativeColumnBuilder(columnStats, columnType) + with NullableColumnBuilder + with CompressibleColumnBuilder[T] { + + override protected def isWorthCompressing(encoder: Encoder) = true +} + +object TestCompressibleColumnBuilder { + def apply[T <: NativeType]( + columnStats: NativeColumnStats[T], + columnType: NativeColumnType[T], + scheme: CompressionScheme) = { + + new TestCompressibleColumnBuilder(columnStats, columnType, Seq(scheme)) + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 71caa709afca6..ea1733b3614e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -30,6 +30,9 @@ import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.util.getTempFilePath import org.apache.spark.sql.test.TestSQLContext +// Implicits +import org.apache.spark.sql.test.TestSQLContext._ + class ParquetQuerySuite extends FunSuite with BeforeAndAfterAll { override def beforeAll() { ParquetTestData.writeFile() @@ -39,6 +42,22 @@ class ParquetQuerySuite extends FunSuite with BeforeAndAfterAll { ParquetTestData.testFile.delete() } + test("self-join parquet files") { + val x = ParquetTestData.testData.subquery('x) + val y = ParquetTestData.testData.subquery('y) + val query = x.join(y).where("x.myint".attr === "y.myint".attr) + + // Check to make sure that the attributes from either side of the join have unique expression + // ids. + query.queryExecution.analyzed.output.filter(_.name == "myint") match { + case Seq(i1, i2) if(i1.exprId == i2.exprId) => + fail(s"Duplicate expression IDs found in query plan: $query") + case Seq(_, _) => // All good + } + + // TODO: We can't run this query as it NPEs + } + test("Import of simple Parquet file") { val result = getRDD(ParquetTestData.testData).collect() assert(result.size === 15) diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 7b5ea98f27ff5..63f592cb4b441 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -30,6 +30,17 @@ jar Spark Project Hive http://spark.apache.org/ + + + yarn-alpha + + + org.apache.avro + avro + + + + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index fc5057b73fe24..46febbfad037d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -188,13 +188,13 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val hiveContext = self override val strategies: Seq[Strategy] = Seq( - TopK, + TakeOrdered, ParquetOperations, HiveTableScans, DataSinks, Scripts, PartialAggregation, - SparkEquiInnerJoin, + HashJoin, BasicOperators, CartesianProduct, BroadcastNestedLoopJoin diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 4f8353666a12b..29834a11f41dc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -141,6 +141,13 @@ class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Logging { */ override def registerTable( databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = ??? + + /** + * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. + * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. + */ + override def unregisterTable( + databaseName: Option[String], tableName: String): Unit = ??? } object HiveMetastoreTypes extends RegexParsers { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index f4b61381f9a27..b2b03bc790fcc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -529,7 +529,7 @@ object HiveQl { val withLimit = limitClause.map(l => nodeToExpr(l.getChildren.head)) - .map(StopAfter(_, withSort)) + .map(Limit(_, withSort)) .getOrElse(withSort) // TOK_INSERT_INTO means to add files to the table. @@ -602,7 +602,7 @@ object HiveQl { case Token("TOK_TABLESPLITSAMPLE", Token("TOK_ROWCOUNT", Nil) :: Token(count, Nil) :: Nil) => - StopAfter(Literal(count.toInt), relation) + Limit(Literal(count.toInt), relation) case Token("TOK_TABLESPLITSAMPLE", Token("TOK_PERCENT", Nil) :: Token(fraction, Nil) :: Nil) => @@ -662,7 +662,7 @@ object HiveQl { // worth the number of hacks that will be required to implement it. Namely, we need to add // some sort of mapped star expansion that would expand all child output row to be similarly // named output expressions where some aggregate expression has been applied (i.e. First). - ??? /// Aggregate(groups, Star(None, First(_)) :: Nil, joinedResult) + ??? // Aggregate(groups, Star(None, First(_)) :: Nil, joinedResult) case Token(allJoinTokens(joinToken), relation1 :: @@ -847,12 +847,9 @@ object HiveQl { case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right)) case Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right)) case Token("<=", left :: right:: Nil) => LessThanOrEqual(nodeToExpr(left), nodeToExpr(right)) - case Token("LIKE", left :: right:: Nil) => - UnresolvedFunction("LIKE", Seq(nodeToExpr(left), nodeToExpr(right))) - case Token("RLIKE", left :: right:: Nil) => - UnresolvedFunction("RLIKE", Seq(nodeToExpr(left), nodeToExpr(right))) - case Token("REGEXP", left :: right:: Nil) => - UnresolvedFunction("REGEXP", Seq(nodeToExpr(left), nodeToExpr(right))) + case Token("LIKE", left :: right:: Nil) => Like(nodeToExpr(left), nodeToExpr(right)) + case Token("RLIKE", left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) + case Token("REGEXP", left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) case Token("TOK_FUNCTION", Token("TOK_ISNOTNULL", Nil) :: child :: Nil) => IsNotNull(nodeToExpr(child)) case Token("TOK_FUNCTION", Token("TOK_ISNULL", Nil) :: child :: Nil) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index ca5311344615f..0da5eb754cb3f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -94,7 +94,7 @@ class HadoopTableReader(@transient _tableDesc: TableDesc, @transient sc: HiveCon val tablePath = hiveTable.getPath val inputPathStr = applyFilterIfNeeded(tablePath, filterOpt) - //logDebug("Table input: %s".format(tablePath)) + // logDebug("Table input: %s".format(tablePath)) val ifc = hiveTable.getInputFormatClass .asInstanceOf[java.lang.Class[InputFormat[Writable, Writable]]] val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc) diff --git a/sql/hive/src/test/resources/golden/alias.*-0-7bdb861d11e895aaea545810cdac316d b/sql/hive/src/test/resources/golden/alias.*-0-7bdb861d11e895aaea545810cdac316d deleted file mode 100644 index 5f4de85940513..0000000000000 --- a/sql/hive/src/test/resources/golden/alias.*-0-7bdb861d11e895aaea545810cdac316d +++ /dev/null @@ -1 +0,0 @@ -0 val_0 \ No newline at end of file diff --git a/sql/hive/src/test/resources/golden/alias.star-0-7bdb861d11e895aaea545810cdac316d b/sql/hive/src/test/resources/golden/alias.star-0-7bdb861d11e895aaea545810cdac316d new file mode 100644 index 0000000000000..016f64cc26f2a --- /dev/null +++ b/sql/hive/src/test/resources/golden/alias.star-0-7bdb861d11e895aaea545810cdac316d @@ -0,0 +1 @@ +0 val_0 diff --git a/sql/hive/src/test/resources/golden/read from cached table-0-ce3797dc14a603cba2a5e58c8612de5b b/sql/hive/src/test/resources/golden/read from cached table-0-ce3797dc14a603cba2a5e58c8612de5b new file mode 100644 index 0000000000000..60878ffb77064 --- /dev/null +++ b/sql/hive/src/test/resources/golden/read from cached table-0-ce3797dc14a603cba2a5e58c8612de5b @@ -0,0 +1 @@ +238 val_238 diff --git a/sql/hive/src/test/resources/golden/read from uncached table-0-ce3797dc14a603cba2a5e58c8612de5b b/sql/hive/src/test/resources/golden/read from uncached table-0-ce3797dc14a603cba2a5e58c8612de5b new file mode 100644 index 0000000000000..60878ffb77064 --- /dev/null +++ b/sql/hive/src/test/resources/golden/read from uncached table-0-ce3797dc14a603cba2a5e58c8612de5b @@ -0,0 +1 @@ +238 val_238 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala new file mode 100644 index 0000000000000..68d45e53cdf26 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.execution.SparkLogicalPlan +import org.apache.spark.sql.columnar.InMemoryColumnarTableScan +import org.apache.spark.sql.hive.execution.HiveComparisonTest + +class CachedTableSuite extends HiveComparisonTest { + TestHive.loadTestTable("src") + + test("cache table") { + TestHive.cacheTable("src") + } + + createQueryTest("read from cached table", + "SELECT * FROM src LIMIT 1") + + test("check that table is cached and uncache") { + TestHive.table("src").queryExecution.analyzed match { + case SparkLogicalPlan(_ : InMemoryColumnarTableScan) => // Found evidence of caching + case noCache => fail(s"No cache node found in plan $noCache") + } + TestHive.uncacheTable("src") + } + + createQueryTest("read from uncached table", + "SELECT * FROM src LIMIT 1") + + test("make sure table is uncached") { + TestHive.table("src").queryExecution.analyzed match { + case cachePlan @ SparkLogicalPlan(_ : InMemoryColumnarTableScan) => + fail(s"Table still cached after uncache: $cachePlan") + case noCache => // Table uncached successfully + } + } + + test("correct error on uncache of non-cached table") { + intercept[IllegalArgumentException] { + TestHive.uncacheTable("src") + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index d77900ddc950c..40c4e23f90fb8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -48,7 +48,7 @@ class HiveResolutionSuite extends HiveComparisonTest { createQueryTest("attr", "SELECT key FROM src a ORDER BY key LIMIT 1") - createQueryTest("alias.*", + createQueryTest("alias.star", "SELECT a.* FROM src a ORDER BY key LIMIT 1") test("case insensitivity with scala reflection") { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index fde46705d89fb..d3339063cc079 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -153,7 +153,7 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { def validate() { this.synchronized { assert(batchDuration != null, "Batch duration has not been set") - //assert(batchDuration >= Milliseconds(100), "Batch duration of " + batchDuration + + // assert(batchDuration >= Milliseconds(100), "Batch duration of " + batchDuration + // " is very low") assert(getOutputStreams().size > 0, "No output streams registered, so nothing to execute") } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala index a3ee1213200e9..6b93f723c3e56 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala @@ -153,7 +153,6 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging } catch { case ie: InterruptedException => logInfo("Receiving thread interrupted") - //println("Receiving thread interrupted") case e: Exception => stopOnError(e) } @@ -167,7 +166,7 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging def stop() { receivingThread.interrupt() onStop() - //TODO: terminate the actor + // TODO: terminate the actor } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala index ca0a8ae47864d..b334d68bf9910 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -78,7 +78,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( override def checkpoint(interval: Duration): DStream[(K, V)] = { super.checkpoint(interval) - //reducedStream.checkpoint(interval) + // reducedStream.checkpoint(interval) this } @@ -128,7 +128,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( // Cogroup the reduced RDDs and merge the reduced values val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(K, _)]]], partitioner) - //val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _ + // val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _ val numOldValues = oldRDDs.size val numNewValues = newRDDs.size diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index 9d8889b655356..5f7d3ba26c656 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -64,7 +64,6 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( } val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner) val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning) - //logDebug("Generating state RDD for time " + validTime) Some(stateRDD) } case None => { // If parent RDD does not exist @@ -97,11 +96,11 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( val groupedRDD = parentRDD.groupByKey(partitioner) val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning) - //logDebug("Generating state RDD for time " + validTime + " (first)") + // logDebug("Generating state RDD for time " + validTime + " (first)") Some(sessionRDD) } case None => { // If parent RDD does not exist, then nothing to do! - //logDebug("Not generating state RDD (no previous state, no parent)") + // logDebug("Not generating state RDD (no previous state, no parent)") None } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala index 6551535f876a1..ed096a0b94517 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala @@ -136,7 +136,7 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging { receiverInfo -= streamId logError("De-registered receiver for network stream " + streamId + " with message " + msg) - //TODO: Do something about the corresponding NetworkInputDStream + // TODO: Do something about the corresponding NetworkInputDStream } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 0784e562ac719..25739956cb889 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -252,7 +252,7 @@ class CheckpointSuite extends TestSuiteBase { ssc.start() // Create files and advance manual clock to process them - //var clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + // var clock = ssc.scheduler.clock.asInstanceOf[ManualClock] Thread.sleep(1000) for (i <- Seq(1, 2, 3)) { Files.write(i + "\n", new File(testDir, i.toString), Charset.forName("UTF-8")) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 723ea18e91dbf..3309f9abb304d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -154,7 +154,8 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Set up the streaming context and input streams val ssc = new StreamingContext(conf, batchDuration) val networkStream = ssc.actorStream[String](Props(new TestActor(port)), "TestActor", - StorageLevel.MEMORY_AND_DISK) //Had to pass the local value of port to prevent from closing over entire scope + // Had to pass the local value of port to prevent from closing over entire scope + StorageLevel.MEMORY_AND_DISK) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] val outputStream = new TestOutputStream(networkStream, outputBuffer) def output = outputBuffer.flatMap(x => x) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala index 7d18a0fcf7ba8..9ebf7b484f421 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala @@ -36,8 +36,9 @@ class RateLimitedOutputStreamSuite extends FunSuite { val stream = new RateLimitedOutputStream(underlying, desiredBytesPerSec = 10000) val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) } - // We accept anywhere from 4.0 to 4.99999 seconds since the value is rounded down. - assert(SECONDS.convert(elapsedNs, NANOSECONDS) === 4) + val seconds = SECONDS.convert(elapsedNs, NANOSECONDS) + assert(seconds >= 4, s"Seconds value ($seconds) is less than 4.") + assert(seconds <= 30, s"Took more than 30 seconds ($seconds) to write data.") assert(underlying.toString("UTF-8") === data) } } diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 71a64ecf5879a..0179b0600c61f 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -167,6 +167,9 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa object Client { def main(argStrings: Array[String]) { + println("WARNING: This client is deprecated and will be removed in a future version of Spark.") + println("Use ./bin/spark-submit with \"--master yarn\"") + // Set an env variable indicating we are running in YARN mode. // Note that anything with SPARK prefix gets propagated to all (remote) processes System.setProperty("SPARK_YARN_MODE", "true") diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index c565f2dde24fc..3e4c739e34fe9 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -63,7 +63,10 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { userClass = value args = tail - case ("--args") :: value :: tail => + case ("--args" | "--arg") :: value :: tail => + if (args(0) == "--args") { + println("--args is deprecated. Use --arg instead.") + } userArgsBuffer += value args = tail @@ -146,8 +149,8 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { "Options:\n" + " --jar JAR_PATH Path to your application's JAR file (required in yarn-cluster mode)\n" + " --class CLASS_NAME Name of your application's main class (required)\n" + - " --args ARGS Arguments to be passed to your application's main class.\n" + - " Mutliple invocations are possible, each will be passed in order.\n" + + " --arg ARGS Argument to be passed to your application's main class.\n" + + " Multiple invocations are possible, each will be passed in order.\n" + " --num-executors NUM Number of executors to start (Default: 2)\n" + " --executor-cores NUM Number of cores for the executors (Default: 1).\n" + " --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: 512 Mb)\n" + diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 57e5761cba896..6568003bf1008 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -139,7 +139,6 @@ trait ClientBase extends Logging { } else if (srcHost != null && dstHost == null) { return false } - //check for ports if (srcUri.getPort() != dstUri.getPort()) { false } else { diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala index 68cda0f1c9f8b..9b7f1fca96c6d 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala @@ -157,7 +157,7 @@ class ClientDistributedCacheManager() extends Logging { def isPublic(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): Boolean = { val fs = FileSystem.get(uri, conf) val current = new Path(uri.getPath()) - //the leaf level file should be readable by others + // the leaf level file should be readable by others if (!checkPermissionOfOther(fs, current, FsAction.READ, statCache)) { return false } @@ -177,7 +177,7 @@ class ClientDistributedCacheManager() extends Logging { statCache: Map[URI, FileStatus]): Boolean = { var current = path while (current != null) { - //the subdirs in the path should have execute permissions for others + // the subdirs in the path should have execute permissions for others if (!checkPermissionOfOther(fs, current, FsAction.EXECUTE, statCache)) { return false } diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index d1f13e3c369ed..161918859e7c4 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -33,11 +33,12 @@ private[spark] class YarnClientSchedulerBackend( var client: Client = null var appId: ApplicationId = null - private[spark] def addArg(optionName: String, optionalParam: String, arrayBuf: ArrayBuffer[String]) { - Option(System.getenv(optionalParam)) foreach { - optParam => { - arrayBuf += (optionName, optParam) - } + private[spark] def addArg(optionName: String, envVar: String, sysProp: String, + arrayBuf: ArrayBuffer[String]) { + if (System.getProperty(sysProp) != null) { + arrayBuf += (optionName, System.getProperty(sysProp)) + } else if (System.getenv(envVar) != null) { + arrayBuf += (optionName, System.getenv(envVar)) } } @@ -56,22 +57,24 @@ private[spark] class YarnClientSchedulerBackend( "--am-class", "org.apache.spark.deploy.yarn.ExecutorLauncher" ) - // process any optional arguments, use the defaults already defined in ClientArguments - // if things aren't specified - Map("SPARK_MASTER_MEMORY" -> "--driver-memory", - "SPARK_DRIVER_MEMORY" -> "--driver-memory", - "SPARK_WORKER_INSTANCES" -> "--num-executors", - "SPARK_WORKER_MEMORY" -> "--executor-memory", - "SPARK_WORKER_CORES" -> "--executor-cores", - "SPARK_EXECUTOR_INSTANCES" -> "--num-executors", - "SPARK_EXECUTOR_MEMORY" -> "--executor-memory", - "SPARK_EXECUTOR_CORES" -> "--executor-cores", - "SPARK_YARN_QUEUE" -> "--queue", - "SPARK_YARN_APP_NAME" -> "--name", - "SPARK_YARN_DIST_FILES" -> "--files", - "SPARK_YARN_DIST_ARCHIVES" -> "--archives") - .foreach { case (optParam, optName) => addArg(optName, optParam, argsArrayBuf) } - + // process any optional arguments, given either as environment variables + // or system properties. use the defaults already defined in ClientArguments + // if things aren't specified. system properties override environment + // variables. + List(("--driver-memory", "SPARK_MASTER_MEMORY", "spark.master.memory"), + ("--driver-memory", "SPARK_DRIVER_MEMORY", "spark.driver.memory"), + ("--num-executors", "SPARK_WORKER_INSTANCES", "spark.worker.instances"), + ("--num-executors", "SPARK_EXECUTOR_INSTANCES", "spark.executor.instances"), + ("--executor-memory", "SPARK_WORKER_MEMORY", "spark.executor.memory"), + ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"), + ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"), + ("--executor-cores", "SPARK_EXECUTOR_CORES", "spark.executor.cores"), + ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue"), + ("--name", "SPARK_YARN_APP_NAME", "spark.app.name"), + ("--files", "SPARK_YARN_DIST_FILES", "spark.yarn.dist.files"), + ("--archives", "SPARK_YARN_DIST_ARCHIVES", "spark.yarn.dist.archives")) + .foreach { case (optName, envVar, sysProp) => addArg(optName, envVar, sysProp, argsArrayBuf) } + logDebug("ClientArguments called with: " + argsArrayBuf) val args = new ClientArguments(argsArrayBuf.toArray, conf) client = new Client(args, conf) diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala index 458df4fa3cd99..80b57d1355a3a 100644 --- a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala @@ -99,7 +99,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { assert(env.get("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === None) assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None) - //add another one and verify both there and order correct + // add another one and verify both there and order correct val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", null, new Path("/tmp/testing2")) val destPath2 = new Path("file:///foo.invalid.com:8080/tmp/testing2") diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 837b7e12cb0de..77eb1276a0c4e 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -173,6 +173,9 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa object Client { def main(argStrings: Array[String]) { + println("WARNING: This client is deprecated and will be removed in a future version of Spark.") + println("Use ./bin/spark-submit with \"--master yarn\"") + // Set an env variable indicating we are running in YARN mode. // Note: anything env variable with SPARK_ prefix gets propagated to all (remote) processes - // see Client#setupLaunchEnv().