diff --git a/core/pom.xml b/core/pom.xml index eb6cc4d3105e9..e4c32eff0cd77 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -150,7 +150,7 @@ json4s-jackson_${scala.binary.version} 3.2.6 diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 6e1736f6fbc23..e1a1f209c9282 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -18,13 +18,14 @@ package org.apache.spark.ui import java.net.{InetSocketAddress, URL} +import javax.servlet.DispatcherType import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import scala.annotation.tailrec import scala.util.{Failure, Success, Try} import scala.xml.Node -import org.eclipse.jetty.server.{DispatcherType, Server} +import org.eclipse.jetty.server.Server import org.eclipse.jetty.server.handler._ import org.eclipse.jetty.servlet._ import org.eclipse.jetty.util.thread.QueuedThreadPool diff --git a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala index 4d8b01dbe6e1b..a7b24ff695214 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala @@ -84,7 +84,7 @@ private[ui] class BlockManagerListener(storageStatusListener: StorageStatusListe override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = synchronized { val rddInfo = stageSubmitted.stageInfo.rddInfo - _rddInfoMap(rddInfo.id) = rddInfo + _rddInfoMap.getOrElseUpdate(rddInfo.id, rddInfo) } override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { diff --git a/dev/audit-release/README.md b/dev/audit-release/README.md index 2437a98672177..38becda0eae92 100644 --- a/dev/audit-release/README.md +++ b/dev/audit-release/README.md @@ -4,7 +4,7 @@ run them locally by setting appropriate environment variables. ``` $ cd sbt_app_core -$ SCALA_VERSION=2.10.3 \ +$ SCALA_VERSION=2.10.4 \ SPARK_VERSION=1.0.0-SNAPSHOT \ SPARK_RELEASE_REPOSITORY=file:///home/patrick/.ivy2/local \ sbt run diff --git a/dev/audit-release/audit_release.py b/dev/audit-release/audit_release.py index 52c367d9b030d..fa2f02dfecc75 100755 --- a/dev/audit-release/audit_release.py +++ b/dev/audit-release/audit_release.py @@ -35,7 +35,7 @@ RELEASE_KEY = "9E4FE3AF" RELEASE_REPOSITORY = "https://repository.apache.org/content/repositories/orgapachespark-1006/" RELEASE_VERSION = "1.0.0" -SCALA_VERSION = "2.10.3" +SCALA_VERSION = "2.10.4" SCALA_BINARY_VERSION = "2.10" ## diff --git a/docker/spark-test/base/Dockerfile b/docker/spark-test/base/Dockerfile index e543db6143e4d..5956d59130fbf 100644 --- a/docker/spark-test/base/Dockerfile +++ b/docker/spark-test/base/Dockerfile @@ -25,7 +25,7 @@ RUN apt-get update # install a few other useful packages plus Open Jdk 7 RUN apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server -ENV SCALA_VERSION 2.10.3 +ENV SCALA_VERSION 2.10.4 ENV CDH_VERSION cdh4 ENV SCALA_HOME /opt/scala-$SCALA_VERSION ENV SPARK_HOME /opt/spark diff --git a/docs/_config.yml b/docs/_config.yml index aa5a5adbc1743..d585b8c5ea763 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -6,7 +6,7 @@ markdown: kramdown SPARK_VERSION: 1.0.0-SNAPSHOT SPARK_VERSION_SHORT: 1.0.0 SCALA_BINARY_VERSION: "2.10" -SCALA_VERSION: "2.10.3" +SCALA_VERSION: "2.10.4" MESOS_VERSION: 0.13.0 SPARK_ISSUE_TRACKER_URL: https://spark-project.atlassian.net SPARK_GITHUB_URL: https://github.com/apache/spark diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index d8657c4bc7096..982514391ac00 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -61,7 +61,7 @@ The command to launch the Spark application on the cluster is as follows: SPARK_JAR= ./bin/spark-class org.apache.spark.deploy.yarn.Client \ --jar \ --class \ - --args \ + --arg \ --num-executors \ --driver-memory \ --executor-memory \ @@ -72,7 +72,7 @@ The command to launch the Spark application on the cluster is as follows: --files \ --archives -For example: +To pass multiple arguments the "arg" option can be specified multiple times. For example: # Build the Spark assembly JAR and the Spark examples JAR $ SPARK_HADOOP_VERSION=2.0.5-alpha SPARK_YARN=true sbt/sbt assembly @@ -85,7 +85,8 @@ For example: ./bin/spark-class org.apache.spark.deploy.yarn.Client \ --jar examples/target/scala-{{site.SCALA_BINARY_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \ --class org.apache.spark.examples.SparkPi \ - --args yarn-cluster \ + --arg yarn-cluster \ + --arg 5 \ --num-executors 3 \ --driver-memory 4g \ --executor-memory 2g \ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala new file mode 100644 index 0000000000000..33205b919db8f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -0,0 +1,1150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import scala.util.control.Breaks._ + +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} +import org.apache.spark.mllib.tree.model._ +import org.apache.spark.rdd.RDD +import org.apache.spark.util.random.XORShiftRandom + +/** + * A class that implements a decision tree algorithm for classification and regression. It + * supports both continuous and categorical features. + * @param strategy The configuration parameters for the tree algorithm which specify the type + * of algorithm (classification, regression, etc.), feature type (continuous, + * categorical), depth of the tree, quantile calculation strategy, etc. + */ +class DecisionTree private(val strategy: Strategy) extends Serializable with Logging { + + /** + * Method to train a decision tree model over an RDD + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * @return a DecisionTreeModel that can be used for prediction + */ + def train(input: RDD[LabeledPoint]): DecisionTreeModel = { + + // Cache input RDD for speedup during multiple passes. + input.cache() + logDebug("algo = " + strategy.algo) + + // Find the splits and the corresponding bins (interval between the splits) using a sample + // of the input data. + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + logDebug("numSplits = " + bins(0).length) + + // depth of the decision tree + val maxDepth = strategy.maxDepth + // the max number of nodes possible given the depth of the tree + val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1 + // Initialize an array to hold filters applied to points for each node. + val filters = new Array[List[Filter]](maxNumNodes) + // The filter at the top node is an empty list. + filters(0) = List() + // Initialize an array to hold parent impurity calculations for each node. + val parentImpurities = new Array[Double](maxNumNodes) + // dummy value for top node (updated during first split calculation) + val nodes = new Array[Node](maxNumNodes) + + + /* + * The main idea here is to perform level-wise training of the decision tree nodes thus + * reducing the passes over the data from l to log2(l) where l is the total number of nodes. + * Each data sample is checked for validity w.r.t to each node at a given level -- i.e., + * the sample is only used for the split calculation at the node if the sampled would have + * still survived the filters of the parent nodes. + */ + + // TODO: Convert for loop to while loop + breakable { + for (level <- 0 until maxDepth) { + + logDebug("#####################################") + logDebug("level = " + level) + logDebug("#####################################") + + // Find best split for all nodes at a level. + val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, + level, filters, splits, bins) + + for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { + // Extract info for nodes at the current level. + extractNodeInfo(nodeSplitStats, level, index, nodes) + // Extract info for nodes at the next lower level. + extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, + filters) + logDebug("final best split = " + nodeSplitStats._1) + } + require(scala.math.pow(2, level) == splitsStatsForLevel.length) + // Check whether all the nodes at the current level at leaves. + val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) + logDebug("all leaf = " + allLeaf) + if (allLeaf) break // no more tree construction + } + } + + // Initialize the top or root node of the tree. + val topNode = nodes(0) + // Build the full tree using the node info calculated in the level-wise best split calculations. + topNode.build(nodes) + + new DecisionTreeModel(topNode, strategy.algo) + } + + /** + * Extract the decision tree node information for the given tree level and node index + */ + private def extractNodeInfo( + nodeSplitStats: (Split, InformationGainStats), + level: Int, + index: Int, + nodes: Array[Node]): Unit = { + val split = nodeSplitStats._1 + val stats = nodeSplitStats._2 + val nodeIndex = scala.math.pow(2, level).toInt - 1 + index + val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1) + val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) + logDebug("Node = " + node) + nodes(nodeIndex) = node + } + + /** + * Extract the decision tree node information for the children of the node + */ + private def extractInfoForLowerLevels( + level: Int, + index: Int, + maxDepth: Int, + nodeSplitStats: (Split, InformationGainStats), + parentImpurities: Array[Double], + filters: Array[List[Filter]]): Unit = { + // 0 corresponds to the left child node and 1 corresponds to the right child node. + // TODO: Convert to while loop + for (i <- 0 to 1) { + // Calculate the index of the node from the node level and the index at the current level. + val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i + if (level < maxDepth - 1) { + val impurity = if (i == 0) { + nodeSplitStats._2.leftImpurity + } else { + nodeSplitStats._2.rightImpurity + } + logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity) + // noting the parent impurities + parentImpurities(nodeIndex) = impurity + // noting the parents filters for the child nodes + val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1) + filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2) + for (filter <- filters(nodeIndex)) { + logDebug("Filter = " + filter) + } + } + } + } +} + +object DecisionTree extends Serializable with Logging { + + /** + * Method to train a decision tree model where the instances are represented as an RDD of + * (label, features) pairs. The method supports binary classification and regression. For the + * binary classification, the label for each instance should either be 0 or 1 to denote the two + * classes. The parameters for the algorithm are specified using the strategy parameter. + * + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @param strategy The configuration parameters for the tree algorithm which specify the type + * of algorithm (classification, regression, etc.), feature type (continuous, + * categorical), depth of the tree, quantile calculation strategy, etc. + * @return a DecisionTreeModel that can be used for prediction + */ + def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { + new DecisionTree(strategy).train(input: RDD[LabeledPoint]) + } + + /** + * Method to train a decision tree model where the instances are represented as an RDD of + * (label, features) pairs. The method supports binary classification and regression. For the + * binary classification, the label for each instance should either be 0 or 1 to denote the two + * classes. + * + * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as + * training data + * @param algo algorithm, classification or regression + * @param impurity impurity criterion used for information gain calculation + * @param maxDepth maxDepth maximum depth of the tree + * @return a DecisionTreeModel that can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + algo: Algo, + impurity: Impurity, + maxDepth: Int): DecisionTreeModel = { + val strategy = new Strategy(algo,impurity,maxDepth) + new DecisionTree(strategy).train(input: RDD[LabeledPoint]) + } + + + /** + * Method to train a decision tree model where the instances are represented as an RDD of + * (label, features) pairs. The decision tree method supports binary classification and + * regression. For the binary classification, the label for each instance should either be 0 or + * 1 to denote the two classes. The method also supports categorical features inputs where the + * number of categories can specified using the categoricalFeaturesInfo option. + * + * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as + * training data for DecisionTree + * @param algo classification or regression + * @param impurity criterion used for information gain calculation + * @param maxDepth maximum depth of the tree + * @param maxBins maximum number of bins used for splitting features + * @param quantileCalculationStrategy algorithm for calculating quantiles + * @param categoricalFeaturesInfo A map storing information about the categorical variables and + * the number of discrete values they take. For example, + * an entry (n -> k) implies the feature n is categorical with k + * categories 0, 1, 2, ... , k-1. It's important to note that + * features are zero-indexed. + * @return a DecisionTreeModel that can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + algo: Algo, + impurity: Impurity, + maxDepth: Int, + maxBins: Int, + quantileCalculationStrategy: QuantileStrategy, + categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { + val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy, + categoricalFeaturesInfo) + new DecisionTree(strategy).train(input: RDD[LabeledPoint]) + } + + private val InvalidBinIndex = -1 + + /** + * Returns an array of optimal splits for all nodes at a given level + * + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @param parentImpurities Impurities for all parent nodes for the current level + * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing + * parameters for construction the DecisionTree + * @param level Level of the tree + * @param filters Filters for all nodes at a given level + * @param splits possible splits for all features + * @param bins possible bins for all features + * @return array of splits with best splits for all nodes at a given level. + */ + protected[tree] def findBestSplits( + input: RDD[LabeledPoint], + parentImpurities: Array[Double], + strategy: Strategy, + level: Int, + filters: Array[List[Filter]], + splits: Array[Array[Split]], + bins: Array[Array[Bin]]): Array[(Split, InformationGainStats)] = { + + /* + * The high-level description for the best split optimizations are noted here. + * + * *Level-wise training* + * We perform bin calculations for all nodes at the given level to avoid making multiple + * passes over the data. Thus, for a slightly increased computation and storage cost we save + * several iterations over the data especially at higher levels of the decision tree. + * + * *Bin-wise computation* + * We use a bin-wise best split computation strategy instead of a straightforward best split + * computation strategy. Instead of analyzing each sample for contribution to the left/right + * child node impurity of every split, we first categorize each feature of a sample into a + * bin. Each bin is an interval between a low and high split. Since each splits, and thus bin, + * is ordered (read ordering for categorical variables in the findSplitsBins method), + * we exploit this structure to calculate aggregates for bins and then use these aggregates + * to calculate information gain for each split. + * + * *Aggregation over partitions* + * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know + * the number of splits in advance. Thus, we store the aggregates (at the appropriate + * indices) in a single array for all bins and rely upon the RDD aggregate method to + * drastically reduce the communication overhead. + */ + + // common calculations for multiple nested methods + val numNodes = scala.math.pow(2, level).toInt + logDebug("numNodes = " + numNodes) + // Find the number of features by looking at the first sample. + val numFeatures = input.first().features.length + logDebug("numFeatures = " + numFeatures) + val numBins = bins(0).length + logDebug("numBins = " + numBins) + + /** Find the filters used before reaching the current code. */ + def findParentFilters(nodeIndex: Int): List[Filter] = { + if (level == 0) { + List[Filter]() + } else { + val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex + filters(nodeFilterIndex) + } + } + + /** + * Find whether the sample is valid input for the current node, i.e., whether it passes through + * all the filters for the current node. + */ + def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = { + // leaf + if ((level > 0) & (parentFilters.length == 0)) { + return false + } + + // Apply each filter and check sample validity. Return false when invalid condition found. + for (filter <- parentFilters) { + val features = labeledPoint.features + val featureIndex = filter.split.feature + val threshold = filter.split.threshold + val comparison = filter.comparison + val categories = filter.split.categories + val isFeatureContinuous = filter.split.featureType == Continuous + val feature = features(featureIndex) + if (isFeatureContinuous) { + comparison match { + case -1 => if (feature > threshold) return false + case 1 => if (feature <= threshold) return false + } + } else { + val containsFeature = categories.contains(feature) + comparison match { + case -1 => if (!containsFeature) return false + case 1 => if (containsFeature) return false + } + + } + } + + // Return true when the sample is valid for all filters. + true + } + + /** + * Find bin for one feature. + */ + def findBin( + featureIndex: Int, + labeledPoint: LabeledPoint, + isFeatureContinuous: Boolean): Int = { + val binForFeatures = bins(featureIndex) + val feature = labeledPoint.features(featureIndex) + + /** + * Binary search helper method for continuous feature. + */ + def binarySearchForBins(): Int = { + var left = 0 + var right = binForFeatures.length - 1 + while (left <= right) { + val mid = left + (right - left) / 2 + val bin = binForFeatures(mid) + val lowThreshold = bin.lowSplit.threshold + val highThreshold = bin.highSplit.threshold + if ((lowThreshold < feature) & (highThreshold >= feature)){ + return mid + } + else if (lowThreshold >= feature) { + right = mid - 1 + } + else { + left = mid + 1 + } + } + -1 + } + + /** + * Sequential search helper method to find bin for categorical feature. + */ + def sequentialBinSearchForCategoricalFeature(): Int = { + val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex) + var binIndex = 0 + while (binIndex < numCategoricalBins) { + val bin = bins(featureIndex)(binIndex) + val category = bin.category + val features = labeledPoint.features + if (category == features(featureIndex)) { + return binIndex + } + binIndex += 1 + } + -1 + } + + if (isFeatureContinuous) { + // Perform binary search for finding bin for continuous features. + val binIndex = binarySearchForBins() + if (binIndex == -1){ + throw new UnknownError("no bin was found for continuous variable.") + } + binIndex + } else { + // Perform sequential search to find bin for categorical features. + val binIndex = sequentialBinSearchForCategoricalFeature() + if (binIndex == -1){ + throw new UnknownError("no bin was found for categorical variable.") + } + binIndex + } + } + + /** + * Finds bins for all nodes (and all features) at a given level. + * For l nodes, k features the storage is as follows: + * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk, + * where b_ij is an integer between 0 and numBins - 1. + * Invalid sample is denoted by noting bin for feature 1 as -1. + */ + def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { + // Calculate bin index and label per feature per node. + val arr = new Array[Double](1 + (numFeatures * numNodes)) + arr(0) = labeledPoint.label + var nodeIndex = 0 + while (nodeIndex < numNodes) { + val parentFilters = findParentFilters(nodeIndex) + // Find out whether the sample qualifies for the particular node. + val sampleValid = isSampleValid(parentFilters, labeledPoint) + val shift = 1 + numFeatures * nodeIndex + if (!sampleValid) { + // Mark one bin as -1 is sufficient. + arr(shift) = InvalidBinIndex + } else { + var featureIndex = 0 + while (featureIndex < numFeatures) { + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinuous) + featureIndex += 1 + } + } + nodeIndex += 1 + } + arr + } + + /** + * Performs a sequential aggregation over a partition for classification. For l nodes, + * k features, either the left count or the right count of one of the p bins is + * incremented based upon whether the feature is classified as 0 or 1. + * + * @param agg Array[Double] storing aggregate calculation of size + * 2 * numSplits * numFeatures*numNodes for classification + * @param arr Array[Double] of size 1 + (numFeatures * numNodes) + * @return Array[Double] storing aggregate calculation of size + * 2 * numSplits * numFeatures * numNodes for classification + */ + def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { + // Iterate over all nodes. + var nodeIndex = 0 + while (nodeIndex < numNodes) { + // Check whether the instance was valid for this nodeIndex. + val validSignalIndex = 1 + numFeatures * nodeIndex + val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex + if (isSampleValidForNode) { + // actual class label + val label = arr(0) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // Find the bin index for this feature. + val arrShift = 1 + numFeatures * nodeIndex + val arrIndex = arrShift + featureIndex + // Update the left or right count for one bin. + val aggShift = 2 * numBins * numFeatures * nodeIndex + val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2 + label match { + case 0.0 => agg(aggIndex) = agg(aggIndex) + 1 + case 1.0 => agg(aggIndex + 1) = agg(aggIndex + 1) + 1 + } + featureIndex += 1 + } + } + nodeIndex += 1 + } + } + + /** + * Performs a sequential aggregation over a partition for regression. For l nodes, k features, + * the count, sum, sum of squares of one of the p bins is incremented. + * + * @param agg Array[Double] storing aggregate calculation of size + * 3 * numSplits * numFeatures * numNodes for classification + * @param arr Array[Double] of size 1 + (numFeatures * numNodes) + * @return Array[Double] storing aggregate calculation of size + * 3 * numSplits * numFeatures * numNodes for regression + */ + def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) { + // Iterate over all nodes. + var nodeIndex = 0 + while (nodeIndex < numNodes) { + // Check whether the instance was valid for this nodeIndex. + val validSignalIndex = 1 + numFeatures * nodeIndex + val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex + if (isSampleValidForNode) { + // actual class label + val label = arr(0) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // Find the bin index for this feature. + val arrShift = 1 + numFeatures * nodeIndex + val arrIndex = arrShift + featureIndex + // Update count, sum, and sum^2 for one bin. + val aggShift = 3 * numBins * numFeatures * nodeIndex + val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3 + agg(aggIndex) = agg(aggIndex) + 1 + agg(aggIndex + 1) = agg(aggIndex + 1) + label + agg(aggIndex + 2) = agg(aggIndex + 2) + label*label + featureIndex += 1 + } + } + nodeIndex += 1 + } + } + + /** + * Performs a sequential aggregation over a partition. + */ + def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { + strategy.algo match { + case Classification => classificationBinSeqOp(arr, agg) + case Regression => regressionBinSeqOp(arr, agg) + } + agg + } + + // Calculate bin aggregate length for classification or regression. + val binAggregateLength = strategy.algo match { + case Classification => 2 * numBins * numFeatures * numNodes + case Regression => 3 * numBins * numFeatures * numNodes + } + logDebug("binAggregateLength = " + binAggregateLength) + + /** + * Combines the aggregates from partitions. + * @param agg1 Array containing aggregates from one or more partitions + * @param agg2 Array containing aggregates from one or more partitions + * @return Combined aggregate from agg1 and agg2 + */ + def binCombOp(agg1: Array[Double], agg2: Array[Double]): Array[Double] = { + var index = 0 + val combinedAggregate = new Array[Double](binAggregateLength) + while (index < binAggregateLength) { + combinedAggregate(index) = agg1(index) + agg2(index) + index += 1 + } + combinedAggregate + } + + // Find feature bins for all nodes at a level. + val binMappedRDD = input.map(x => findBinsForLevel(x)) + + // Calculate bin aggregates. + val binAggregates = { + binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) + } + logDebug("binAggregates.length = " + binAggregates.length) + + /** + * Calculates the information gain for all splits based upon left/right split aggregates. + * @param leftNodeAgg left node aggregates + * @param featureIndex feature index + * @param splitIndex split index + * @param rightNodeAgg right node aggregate + * @param topImpurity impurity of the parent node + * @return information gain and statistics for all splits + */ + def calculateGainForSplit( + leftNodeAgg: Array[Array[Double]], + featureIndex: Int, + splitIndex: Int, + rightNodeAgg: Array[Array[Double]], + topImpurity: Double): InformationGainStats = { + strategy.algo match { + case Classification => + val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex) + val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1) + val leftCount = left0Count + left1Count + + val right0Count = rightNodeAgg(featureIndex)(2 * splitIndex) + val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1) + val rightCount = right0Count + right1Count + + val impurity = { + if (level > 0) { + topImpurity + } else { + // Calculate impurity for root node. + strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count) + } + } + + if (leftCount == 0) { + return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,1) + } + if (rightCount == 0) { + return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue,0) + } + + val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) + val rightImpurity = strategy.impurity.calculate(right0Count, right1Count) + + val leftWeight = leftCount.toDouble / (leftCount + rightCount) + val rightWeight = rightCount.toDouble / (leftCount + rightCount) + + val gain = { + if (level > 0) { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } else { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } + } + + val predict = (left1Count + right1Count) / (leftCount + rightCount) + + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) + case Regression => + val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex) + val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1) + val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2) + + val rightCount = rightNodeAgg(featureIndex)(3 * splitIndex) + val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1) + val rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2) + + val impurity = { + if (level > 0) { + topImpurity + } else { + // Calculate impurity for root node. + val count = leftCount + rightCount + val sum = leftSum + rightSum + val sumSquares = leftSumSquares + rightSumSquares + strategy.impurity.calculate(count, sum, sumSquares) + } + } + + if (leftCount == 0) { + return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, + rightSum / rightCount) + } + if (rightCount == 0) { + return new InformationGainStats(0, topImpurity ,topImpurity, + Double.MinValue, leftSum / leftCount) + } + + val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares) + val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares) + + val leftWeight = leftCount.toDouble / (leftCount + rightCount) + val rightWeight = rightCount.toDouble / (leftCount + rightCount) + + val gain = { + if (level > 0) { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } else { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } + } + + val predict = (leftSum + rightSum) / (leftCount + rightCount) + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) + } + } + + /** + * Extracts left and right split aggregates. + * @param binData Array[Double] of size 2*numFeatures*numSplits + * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double], + * Array[Double]) where each array is of size(numFeature,2*(numSplits-1)) + */ + def extractLeftRightNodeAggregates( + binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { + strategy.algo match { + case Classification => + // Initialize left and right split aggregates. + val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) + val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // shift for this featureIndex + val shift = 2 * featureIndex * numBins + + // left node aggregate for the lowest split + leftNodeAgg(featureIndex)(0) = binData(shift + 0) + leftNodeAgg(featureIndex)(1) = binData(shift + 1) + + // right node aggregate for the highest split + rightNodeAgg(featureIndex)(2 * (numBins - 2)) + = binData(shift + (2 * (numBins - 1))) + rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1) + = binData(shift + (2 * (numBins - 1)) + 1) + + // Iterate over all splits. + var splitIndex = 1 + while (splitIndex < numBins - 1) { + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2 * splitIndex) + + leftNodeAgg(featureIndex)(2 * splitIndex - 2) + leftNodeAgg(featureIndex)(2 * splitIndex + 1) = binData(shift + 2 * splitIndex + 1) + + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1) + + // calculating right node aggregate for a split as a sum of right node aggregate of a + // higher split and the right bin aggregate of a bin where the split is a low split + rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) = + binData(shift + (2 *(numBins - 2 - splitIndex))) + + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex)) + rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) = + binData(shift + (2* (numBins - 2 - splitIndex) + 1)) + + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1) + + splitIndex += 1 + } + featureIndex += 1 + } + (leftNodeAgg, rightNodeAgg) + case Regression => + // Initialize left and right split aggregates. + val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) + val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // shift for this featureIndex + val shift = 3 * featureIndex * numBins + // left node aggregate for the lowest split + leftNodeAgg(featureIndex)(0) = binData(shift + 0) + leftNodeAgg(featureIndex)(1) = binData(shift + 1) + leftNodeAgg(featureIndex)(2) = binData(shift + 2) + + // right node aggregate for the highest split + rightNodeAgg(featureIndex)(3 * (numBins - 2)) = + binData(shift + (3 * (numBins - 1))) + rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) = + binData(shift + (3 * (numBins - 1)) + 1) + rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) = + binData(shift + (3 * (numBins - 1)) + 2) + + // Iterate over all splits. + var splitIndex = 1 + while (splitIndex < numBins - 1) { + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + leftNodeAgg(featureIndex)(3 * splitIndex) = binData(shift + 3 * splitIndex) + + leftNodeAgg(featureIndex)(3 * splitIndex - 3) + leftNodeAgg(featureIndex)(3 * splitIndex + 1) = binData(shift + 3 * splitIndex + 1) + + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1) + leftNodeAgg(featureIndex)(3 * splitIndex + 2) = binData(shift + 3 * splitIndex + 2) + + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2) + + // calculating right node aggregate for a split as a sum of right node aggregate of a + // higher split and the right bin aggregate of a bin where the split is a low split + rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) = + binData(shift + (3 * (numBins - 2 - splitIndex))) + + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex)) + rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) = + binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) + + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1) + rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) = + binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) + + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2) + + splitIndex += 1 + } + featureIndex += 1 + } + (leftNodeAgg, rightNodeAgg) + } + } + + /** + * Calculates information gain for all nodes splits. + */ + def calculateGainsForAllNodeSplits( + leftNodeAgg: Array[Array[Double]], + rightNodeAgg: Array[Array[Double]], + nodeImpurity: Double): Array[Array[InformationGainStats]] = { + val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) + + for (featureIndex <- 0 until numFeatures) { + for (splitIndex <- 0 until numBins - 1) { + gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex, + splitIndex, rightNodeAgg, nodeImpurity) + } + } + gains + } + + /** + * Find the best split for a node. + * @param binData Array[Double] of size 2 * numSplits * numFeatures + * @param nodeImpurity impurity of the top node + * @return tuple of split and information gain + */ + def binsToBestSplit( + binData: Array[Double], + nodeImpurity: Double): (Split, InformationGainStats) = { + + logDebug("node impurity = " + nodeImpurity) + + // Extract left right node aggregates. + val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) + + // Calculate gains for all splits. + val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) + + val (bestFeatureIndex,bestSplitIndex, gainStats) = { + // Initialize with infeasible values. + var bestFeatureIndex = Int.MinValue + var bestSplitIndex = Int.MinValue + var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0) + // Iterate over features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // Iterate over all splits. + var splitIndex = 0 + while (splitIndex < numBins - 1) { + val gainStats = gains(featureIndex)(splitIndex) + if (gainStats.gain > bestGainStats.gain) { + bestGainStats = gainStats + bestFeatureIndex = featureIndex + bestSplitIndex = splitIndex + } + splitIndex += 1 + } + featureIndex += 1 + } + (bestFeatureIndex, bestSplitIndex, bestGainStats) + } + + logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex)) + logDebug("best split bin = " + splits(bestFeatureIndex)(bestSplitIndex)) + + (splits(bestFeatureIndex)(bestSplitIndex), gainStats) + } + + /** + * Get bin data for one node. + */ + def getBinDataForNode(node: Int): Array[Double] = { + strategy.algo match { + case Classification => + val shift = 2 * node * numBins * numFeatures + val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures) + binsForNode + case Regression => + val shift = 3 * node * numBins * numFeatures + val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures) + binsForNode + } + } + + // Calculate best splits for all nodes at a given level + val bestSplits = new Array[(Split, InformationGainStats)](numNodes) + // Iterating over all nodes at this level + var node = 0 + while (node < numNodes) { + val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node + val binsForNode: Array[Double] = getBinDataForNode(node) + logDebug("nodeImpurityIndex = " + nodeImpurityIndex) + val parentNodeImpurity = parentImpurities(nodeImpurityIndex) + logDebug("node impurity = " + parentNodeImpurity) + bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) + node += 1 + } + + bestSplits + } + + /** + * Returns split and bins for decision tree calculation. + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing + * parameters for construction the DecisionTree + * @return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree + * .model.Split] of size (numFeatures, numSplits-1) and bins is an Array of [org.apache + * .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1) + */ + protected[tree] def findSplitsBins( + input: RDD[LabeledPoint], + strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { + val count = input.count() + + // Find the number of features by looking at the first sample + val numFeatures = input.take(1)(0).features.length + + val maxBins = strategy.maxBins + val numBins = if (maxBins <= count) maxBins else count.toInt + logDebug("numBins = " + numBins) + + /* + * TODO: Add a require statement ensuring #bins is always greater than the categories. + * It's a limitation of the current implementation but a reasonable trade-off since features + * with large number of categories get favored over continuous features. + */ + if (strategy.categoricalFeaturesInfo.size > 0) { + val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2 + require(numBins >= maxCategoriesForFeatures) + } + + // Calculate the number of sample for approximate quantile calculation. + val requiredSamples = numBins*numBins + val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0 + logDebug("fraction of data used for calculating quantiles = " + fraction) + + // sampled input for RDD calculation + val sampledInput = input.sample(false, fraction, new XORShiftRandom().nextInt()).collect() + val numSamples = sampledInput.length + + val stride: Double = numSamples.toDouble / numBins + logDebug("stride = " + stride) + + strategy.quantileCalculationStrategy match { + case Sort => + val splits = Array.ofDim[Split](numFeatures, numBins - 1) + val bins = Array.ofDim[Bin](numFeatures, numBins) + + // Find all splits. + + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures){ + // Check whether the feature is continuous. + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { + val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted + val stride: Double = numSamples.toDouble / numBins + logDebug("stride = " + stride) + for (index <- 0 until numBins - 1) { + val sampleIndex = (index + 1) * stride.toInt + val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List()) + splits(featureIndex)(index) = split + } + } else { + val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex) + require(maxFeatureValue < numBins, "number of categories should be less than number " + + "of bins") + + // For categorical variables, each bin is a category. The bins are sorted and they + // are ordered by calculating the centroid of their corresponding labels. + val centroidForCategories = + sampledInput.map(lp => (lp.features(featureIndex),lp.label)) + .groupBy(_._1) + .mapValues(x => x.map(_._2).sum / x.map(_._1).length) + + // Check for missing categorical variables and putting them last in the sorted list. + val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]() + for (i <- 0 until maxFeatureValue) { + if (centroidForCategories.contains(i)) { + fullCentroidForCategories(i) = centroidForCategories(i) + } else { + fullCentroidForCategories(i) = Double.MaxValue + } + } + + // bins sorted by centroids + val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2) + + logDebug("centriod for categorical variable = " + categoriesSortedByCentroid) + + var categoriesForSplit = List[Double]() + categoriesSortedByCentroid.iterator.zipWithIndex.foreach { + case ((key, value), index) => + categoriesForSplit = key :: categoriesForSplit + splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical, + categoriesForSplit) + bins(featureIndex)(index) = { + if (index == 0) { + new Bin(new DummyCategoricalSplit(featureIndex, Categorical), + splits(featureIndex)(0), Categorical, key) + } else { + new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), + Categorical, key) + } + } + } + } + featureIndex += 1 + } + + // Find all bins. + featureIndex = 0 + while (featureIndex < numFeatures) { + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { // Bins for categorical variables are already assigned. + bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), + splits(featureIndex)(0), Continuous, Double.MinValue) + for (index <- 1 until numBins - 1){ + val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), + Continuous, Double.MinValue) + bins(featureIndex)(index) = bin + } + bins(featureIndex)(numBins-1) = new Bin(splits(featureIndex)(numBins-2), + new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue) + } + featureIndex += 1 + } + (splits,bins) + case MinMax => + throw new UnsupportedOperationException("minmax not supported yet.") + case ApproxHist => + throw new UnsupportedOperationException("approximate histogram not supported yet.") + } + } + + val usage = """ + Usage: DecisionTreeRunner [slices] --algo --trainDataDir path --testDataDir path --maxDepth num [--impurity ] [--maxBins num] + """ + + def main(args: Array[String]) { + + if (args.length < 2) { + System.err.println(usage) + System.exit(1) + } + + val sc = new SparkContext(args(0), "DecisionTree") + + val argList = args.toList.drop(1) + type OptionMap = Map[Symbol, Any] + + def nextOption(map : OptionMap, list: List[String]): OptionMap = { + list match { + case Nil => map + case "--algo" :: string :: tail => nextOption(map ++ Map('algo -> string), tail) + case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail) + case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail) + case "--maxBins" :: string :: tail => nextOption(map ++ Map('maxBins -> string), tail) + case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string) + , tail) + case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string), + tail) + case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail) + case option :: tail => logError("Unknown option " + option) + sys.exit(1) + } + } + val options = nextOption(Map(), argList) + logDebug(options.toString()) + + // Load training data. + val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString) + + // Identify the type of algorithm. + val algoStr = options.get('algo).get.toString + val algo = algoStr match { + case "Classification" => Classification + case "Regression" => Regression + } + + // Identify the type of impurity. + val impurityStr = options.getOrElse('impurity, + if (algo == Classification) "Gini" else "Variance").toString + val impurity = impurityStr match { + case "Gini" => Gini + case "Entropy" => Entropy + case "Variance" => Variance + } + + val maxDepth = options.getOrElse('maxDepth, "1").toString.toInt + val maxBins = options.getOrElse('maxBins, "100").toString.toInt + + val strategy = new Strategy(algo, impurity, maxDepth, maxBins) + val model = DecisionTree.train(trainData, strategy) + + // Load test data. + val testData = loadLabeledData(sc, options.get('testDataDir).get.toString) + + // Measure algorithm accuracy + if (algo == Classification) { + val accuracy = accuracyScore(model, testData) + logDebug("accuracy = " + accuracy) + } + + if (algo == Regression) { + val mse = meanSquaredError(model, testData) + logDebug("mean square error = " + mse) + } + + sc.stop() + } + + /** + * Load labeled data from a file. The data format used here is + * , ..., + * where , are feature values in Double and is the corresponding label as Double. + * + * @param sc SparkContext + * @param dir Directory to the input data files. + * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is + * the label, and the second element represents the feature values (an array of Double). + */ + def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { + sc.textFile(dir).map { line => + val parts = line.trim().split(",") + val label = parts(0).toDouble + val features = parts.slice(1,parts.length).map(_.toDouble) + LabeledPoint(label, features) + } + } + + // TODO: Port this method to a generic metrics package. + /** + * Calculates the classifier accuracy. + */ + private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint], + threshold: Double = 0.5): Double = { + def predictedValue(features: Array[Double]) = { + if (model.predict(features) < threshold) 0.0 else 1.0 + } + val correctCount = data.filter(y => predictedValue(y.features) == y.label).count() + val count = data.count() + logDebug("correct prediction count = " + correctCount) + logDebug("data count = " + count) + correctCount.toDouble / count + } + + // TODO: Port this method to a generic metrics package + /** + * Calculates the mean squared error for regression. + */ + private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { + data.map { y => + val err = tree.predict(y.features) - y.label + err * err + }.mean() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md b/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md new file mode 100644 index 0000000000000..0fd71aa9735bc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md @@ -0,0 +1,17 @@ +This package contains the default implementation of the decision tree algorithm. + +The decision tree algorithm supports: ++ Binary classification ++ Regression ++ Information loss calculation with entropy and gini for classification and variance for regression ++ Both continuous and categorical features + +# Tree improvements ++ Node model pruning ++ Printing to dot files + +# Future Ensemble Extensions + ++ Random forests ++ Boosting ++ Extremely randomized trees diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala new file mode 100644 index 0000000000000..2dd1f0f27b8f5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +/** + * Enum to select the algorithm for the decision tree + */ +object Algo extends Enumeration { + type Algo = Value + val Classification, Regression = Value +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala new file mode 100644 index 0000000000000..09ee0586c58fa --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +/** + * Enum to describe whether a feature is "continuous" or "categorical" + */ +object FeatureType extends Enumeration { + type FeatureType = Value + val Continuous, Categorical = Value +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala new file mode 100644 index 0000000000000..2457a480c2a14 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +/** + * Enum for selecting the quantile calculation strategy + */ +object QuantileStrategy extends Enumeration { + type QuantileStrategy = Value + val Sort, MinMax, ApproxHist = Value +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala new file mode 100644 index 0000000000000..df565f3eb8859 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ + +/** + * Stores all the configuration options for tree construction + * @param algo classification or regression + * @param impurity criterion used for information gain calculation + * @param maxDepth maximum depth of the tree + * @param maxBins maximum number of bins used for splitting features + * @param quantileCalculationStrategy algorithm for calculating quantiles + * @param categoricalFeaturesInfo A map storing information about the categorical variables and the + * number of discrete values they take. For example, an entry (n -> + * k) implies the feature n is categorical with k categories 0, + * 1, 2, ... , k-1. It's important to note that features are + * zero-indexed. + */ +class Strategy ( + val algo: Algo, + val impurity: Impurity, + val maxDepth: Int, + val maxBins: Int = 100, + val quantileCalculationStrategy: QuantileStrategy = Sort, + val categoricalFeaturesInfo: Map[Int,Int] = Map[Int,Int]()) extends Serializable diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala new file mode 100644 index 0000000000000..b93995fcf9441 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +/** + * Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during + * binary classification. + */ +object Entropy extends Impurity { + + def log2(x: Double) = scala.math.log(x) / scala.math.log(2) + + /** + * entropy calculation + * @param c0 count of instances with label 0 + * @param c1 count of instances with label 1 + * @return entropy value + */ + def calculate(c0: Double, c1: Double): Double = { + if (c0 == 0 || c1 == 0) { + 0 + } else { + val total = c0 + c1 + val f0 = c0 / total + val f1 = c1 / total + -(f0 * log2(f0)) - (f1 * log2(f1)) + } + } + + def calculate(count: Double, sum: Double, sumSquares: Double): Double = + throw new UnsupportedOperationException("Entropy.calculate") +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala new file mode 100644 index 0000000000000..c0407554a91b3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +/** + * Class for calculating the + * [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]] + * during binary classification. + */ +object Gini extends Impurity { + + /** + * Gini coefficient calculation + * @param c0 count of instances with label 0 + * @param c1 count of instances with label 1 + * @return Gini coefficient value + */ + override def calculate(c0: Double, c1: Double): Double = { + if (c0 == 0 || c1 == 0) { + 0 + } else { + val total = c0 + c1 + val f0 = c0 / total + val f1 = c1 / total + 1 - f0 * f0 - f1 * f1 + } + } + + def calculate(count: Double, sum: Double, sumSquares: Double): Double = + throw new UnsupportedOperationException("Gini.calculate") +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala new file mode 100644 index 0000000000000..a4069063af2ad --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +/** + * Trait for calculating information gain. + */ +trait Impurity extends Serializable { + + /** + * information calculation for binary classification + * @param c0 count of instances with label 0 + * @param c1 count of instances with label 1 + * @return information value + */ + def calculate(c0 : Double, c1 : Double): Double + + /** + * information calculation for regression + * @param count number of instances + * @param sum sum of labels + * @param sumSquares summation of squares of the labels + * @return information value + */ + def calculate(count: Double, sum: Double, sumSquares: Double): Double + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala new file mode 100644 index 0000000000000..b74577dcec167 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +/** + * Class for calculating variance during regression + */ +object Variance extends Impurity { + override def calculate(c0: Double, c1: Double): Double = + throw new UnsupportedOperationException("Variance.calculate") + + /** + * variance calculation + * @param count number of instances + * @param sum sum of labels + * @param sumSquares summation of squares of the labels + */ + override def calculate(count: Double, sum: Double, sumSquares: Double): Double = { + val squaredLoss = sumSquares - (sum * sum) / count + squaredLoss / count + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala new file mode 100644 index 0000000000000..a57faa13745f7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import org.apache.spark.mllib.tree.configuration.FeatureType._ + +/** + * Used for "binning" the features bins for faster best split calculation. For a continuous + * feature, a bin is determined by a low and a high "split". For a categorical feature, + * the a bin is determined using a single label value (category). + * @param lowSplit signifying the lower threshold for the continuous feature to be + * accepted in the bin + * @param highSplit signifying the upper threshold for the continuous feature to be + * accepted in the bin + * @param featureType type of feature -- categorical or continuous + * @param category categorical label value accepted in the bin + */ +case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala new file mode 100644 index 0000000000000..a8bbf21daec01 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.rdd.RDD + +/** + * Model to store the decision tree parameters + * @param topNode root node + * @param algo algorithm type -- classification or regression + */ +class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable { + + /** + * Predict values for a single data point using the model trained. + * + * @param features array representing a single data point + * @return Double prediction from the trained model + */ + def predict(features: Array[Double]): Double = { + topNode.predictIfLeaf(features) + } + + /** + * Predict values for the given data set using the model trained. + * + * @param features RDD representing data points to be predicted + * @return RDD[Int] where each entry contains the corresponding prediction + */ + def predict(features: RDD[Array[Double]]): RDD[Double] = { + features.map(x => predict(x)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala new file mode 100644 index 0000000000000..ebc9595eafef3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +/** + * Filter specifying a split and type of comparison to be applied on features + * @param split split specifying the feature index, type and threshold + * @param comparison integer specifying <,=,> + */ +case class Filter(split: Split, comparison: Int) { + // Comparison -1,0,1 signifies <.=,> + override def toString = " split = " + split + "comparison = " + comparison +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala new file mode 100644 index 0000000000000..99bf79cf12e45 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +/** + * Information gain statistics for each split + * @param gain information gain value + * @param impurity current node impurity + * @param leftImpurity left node impurity + * @param rightImpurity right node impurity + * @param predict predicted value + */ +class InformationGainStats( + val gain: Double, + val impurity: Double, + val leftImpurity: Double, + val rightImpurity: Double, + val predict: Double) extends Serializable { + + override def toString = { + "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f" + .format(gain, impurity, leftImpurity, rightImpurity, predict) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala new file mode 100644 index 0000000000000..ea4693c5c2f4e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import org.apache.spark.Logging +import org.apache.spark.mllib.tree.configuration.FeatureType._ + +/** + * Node in a decision tree + * @param id integer node id + * @param predict predicted value at the node + * @param isLeaf whether the leaf is a node + * @param split split to calculate left and right nodes + * @param leftNode left child + * @param rightNode right child + * @param stats information gain stats + */ +class Node ( + val id: Int, + val predict: Double, + val isLeaf: Boolean, + val split: Option[Split], + var leftNode: Option[Node], + var rightNode: Option[Node], + val stats: Option[InformationGainStats]) extends Serializable with Logging { + + override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + + "split = " + split + ", stats = " + stats + + /** + * build the left node and right nodes if not leaf + * @param nodes array of nodes + */ + def build(nodes: Array[Node]): Unit = { + + logDebug("building node " + id + " at level " + + (scala.math.log(id + 1)/scala.math.log(2)).toInt ) + logDebug("id = " + id + ", split = " + split) + logDebug("stats = " + stats) + logDebug("predict = " + predict) + if (!isLeaf) { + val leftNodeIndex = id*2 + 1 + val rightNodeIndex = id*2 + 2 + leftNode = Some(nodes(leftNodeIndex)) + rightNode = Some(nodes(rightNodeIndex)) + leftNode.get.build(nodes) + rightNode.get.build(nodes) + } + } + + /** + * predict value if node is not leaf + * @param feature feature value + * @return predicted value + */ + def predictIfLeaf(feature: Array[Double]) : Double = { + if (isLeaf) { + predict + } else{ + if (split.get.featureType == Continuous) { + if (feature(split.get.feature) <= split.get.threshold) { + leftNode.get.predictIfLeaf(feature) + } else { + rightNode.get.predictIfLeaf(feature) + } + } else { + if (split.get.categories.contains(feature(split.get.feature))) { + leftNode.get.predictIfLeaf(feature) + } else { + rightNode.get.predictIfLeaf(feature) + } + } + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala new file mode 100644 index 0000000000000..4e64a81dda74e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType + +/** + * Split applied to a feature + * @param feature feature index + * @param threshold threshold for continuous feature + * @param featureType type of feature -- categorical or continuous + * @param categories accepted values for categorical variables + */ +case class Split( + feature: Int, + threshold: Double, + featureType: FeatureType, + categories: List[Double]){ + + override def toString = + "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType + + ", categories = " + categories +} + +/** + * Split with minimum threshold for continuous features. Helps with the smallest bin creation. + * @param feature feature index + * @param featureType type of feature -- categorical or continuous + */ +class DummyLowSplit(feature: Int, featureType: FeatureType) + extends Split(feature, Double.MinValue, featureType, List()) + +/** + * Split with maximum threshold for continuous features. Helps with the highest bin creation. + * @param feature feature index + * @param featureType type of feature -- categorical or continuous + */ +class DummyHighSplit(feature: Int, featureType: FeatureType) + extends Split(feature, Double.MaxValue, featureType, List()) + +/** + * Split with no acceptable feature values for categorical features. Helps with the first bin + * creation. + * @param feature feature index + * @param featureType type of feature -- categorical or continuous + */ +class DummyCategoricalSplit(feature: Int, featureType: FeatureType) + extends Split(feature, Double.MaxValue, featureType, List()) + diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala new file mode 100644 index 0000000000000..4349c7000a0ae --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -0,0 +1,425 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} +import org.apache.spark.mllib.tree.model.Filter +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.FeatureType._ + +class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { + + @transient private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + } + + override def afterAll() { + sc.stop() + System.clearProperty("spark.driver.port") + } + + test("split and bin calculation") { + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Gini, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(bins.length === 2) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + } + + test("split and bin calculation for categorical variables") { + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(bins.length === 2) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + // Check splits. + + assert(splits(0)(0).feature === 0) + assert(splits(0)(0).threshold === Double.MinValue) + assert(splits(0)(0).featureType === Categorical) + assert(splits(0)(0).categories.length === 1) + assert(splits(0)(0).categories.contains(1.0)) + + assert(splits(0)(1).feature === 0) + assert(splits(0)(1).threshold === Double.MinValue) + assert(splits(0)(1).featureType === Categorical) + assert(splits(0)(1).categories.length === 2) + assert(splits(0)(1).categories.contains(1.0)) + assert(splits(0)(1).categories.contains(0.0)) + + assert(splits(0)(2) === null) + + assert(splits(1)(0).feature === 1) + assert(splits(1)(0).threshold === Double.MinValue) + assert(splits(1)(0).featureType === Categorical) + assert(splits(1)(0).categories.length === 1) + assert(splits(1)(0).categories.contains(0.0)) + + assert(splits(1)(1).feature === 1) + assert(splits(1)(1).threshold === Double.MinValue) + assert(splits(1)(1).featureType === Categorical) + assert(splits(1)(1).categories.length === 2) + assert(splits(1)(1).categories.contains(1.0)) + assert(splits(1)(1).categories.contains(0.0)) + + assert(splits(1)(2) === null) + + // Check bins. + + assert(bins(0)(0).category === 1.0) + assert(bins(0)(0).lowSplit.categories.length === 0) + assert(bins(0)(0).highSplit.categories.length === 1) + assert(bins(0)(0).highSplit.categories.contains(1.0)) + + assert(bins(0)(1).category === 0.0) + assert(bins(0)(1).lowSplit.categories.length === 1) + assert(bins(0)(1).lowSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.length === 2) + assert(bins(0)(1).highSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.contains(0.0)) + + assert(bins(0)(2) === null) + + assert(bins(1)(0).category === 0.0) + assert(bins(1)(0).lowSplit.categories.length === 0) + assert(bins(1)(0).highSplit.categories.length === 1) + assert(bins(1)(0).highSplit.categories.contains(0.0)) + + assert(bins(1)(1).category === 1.0) + assert(bins(1)(1).lowSplit.categories.length === 1) + assert(bins(1)(1).lowSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.length === 2) + assert(bins(1)(1).highSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.contains(1.0)) + + assert(bins(1)(2) === null) + } + + test("split and bin calculations for categorical variables with no sample for one category") { + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + + // Check splits. + + assert(splits(0)(0).feature === 0) + assert(splits(0)(0).threshold === Double.MinValue) + assert(splits(0)(0).featureType === Categorical) + assert(splits(0)(0).categories.length === 1) + assert(splits(0)(0).categories.contains(1.0)) + + assert(splits(0)(1).feature === 0) + assert(splits(0)(1).threshold === Double.MinValue) + assert(splits(0)(1).featureType === Categorical) + assert(splits(0)(1).categories.length === 2) + assert(splits(0)(1).categories.contains(1.0)) + assert(splits(0)(1).categories.contains(0.0)) + + assert(splits(0)(2).feature === 0) + assert(splits(0)(2).threshold === Double.MinValue) + assert(splits(0)(2).featureType === Categorical) + assert(splits(0)(2).categories.length === 3) + assert(splits(0)(2).categories.contains(1.0)) + assert(splits(0)(2).categories.contains(0.0)) + assert(splits(0)(2).categories.contains(2.0)) + + assert(splits(0)(3) === null) + + assert(splits(1)(0).feature === 1) + assert(splits(1)(0).threshold === Double.MinValue) + assert(splits(1)(0).featureType === Categorical) + assert(splits(1)(0).categories.length === 1) + assert(splits(1)(0).categories.contains(0.0)) + + assert(splits(1)(1).feature === 1) + assert(splits(1)(1).threshold === Double.MinValue) + assert(splits(1)(1).featureType === Categorical) + assert(splits(1)(1).categories.length === 2) + assert(splits(1)(1).categories.contains(1.0)) + assert(splits(1)(1).categories.contains(0.0)) + + assert(splits(1)(2).feature === 1) + assert(splits(1)(2).threshold === Double.MinValue) + assert(splits(1)(2).featureType === Categorical) + assert(splits(1)(2).categories.length === 3) + assert(splits(1)(2).categories.contains(1.0)) + assert(splits(1)(2).categories.contains(0.0)) + assert(splits(1)(2).categories.contains(2.0)) + + assert(splits(1)(3) === null) + + // Check bins. + + assert(bins(0)(0).category === 1.0) + assert(bins(0)(0).lowSplit.categories.length === 0) + assert(bins(0)(0).highSplit.categories.length === 1) + assert(bins(0)(0).highSplit.categories.contains(1.0)) + + assert(bins(0)(1).category === 0.0) + assert(bins(0)(1).lowSplit.categories.length === 1) + assert(bins(0)(1).lowSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.length === 2) + assert(bins(0)(1).highSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.contains(0.0)) + + assert(bins(0)(2).category === 2.0) + assert(bins(0)(2).lowSplit.categories.length === 2) + assert(bins(0)(2).lowSplit.categories.contains(1.0)) + assert(bins(0)(2).lowSplit.categories.contains(0.0)) + assert(bins(0)(2).highSplit.categories.length === 3) + assert(bins(0)(2).highSplit.categories.contains(1.0)) + assert(bins(0)(2).highSplit.categories.contains(0.0)) + assert(bins(0)(2).highSplit.categories.contains(2.0)) + + assert(bins(0)(3) === null) + + assert(bins(1)(0).category === 0.0) + assert(bins(1)(0).lowSplit.categories.length === 0) + assert(bins(1)(0).highSplit.categories.length === 1) + assert(bins(1)(0).highSplit.categories.contains(0.0)) + + assert(bins(1)(1).category === 1.0) + assert(bins(1)(1).lowSplit.categories.length === 1) + assert(bins(1)(1).lowSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.length === 2) + assert(bins(1)(1).highSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.contains(1.0)) + + assert(bins(1)(2).category === 2.0) + assert(bins(1)(2).lowSplit.categories.length === 2) + assert(bins(1)(2).lowSplit.categories.contains(0.0)) + assert(bins(1)(2).lowSplit.categories.contains(1.0)) + assert(bins(1)(2).highSplit.categories.length === 3) + assert(bins(1)(2).highSplit.categories.contains(0.0)) + assert(bins(1)(2).highSplit.categories.contains(1.0)) + assert(bins(1)(2).highSplit.categories.contains(2.0)) + + assert(bins(1)(3) === null) + } + + test("classification stump with all categorical variables") { + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + Array[List[Filter]](), splits, bins) + + val split = bestSplits(0)._1 + assert(split.categories.length === 1) + assert(split.categories.contains(1.0)) + assert(split.featureType === Categorical) + assert(split.threshold === Double.MinValue) + + val stats = bestSplits(0)._2 + assert(stats.gain > 0) + assert(stats.predict > 0.4) + assert(stats.predict < 0.5) + assert(stats.impurity > 0.2) + } + + test("regression stump with all categorical variables") { + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Regression, + Variance, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) + val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + Array[List[Filter]](), splits, bins) + + val split = bestSplits(0)._1 + assert(split.categories.length === 1) + assert(split.categories.contains(1.0)) + assert(split.featureType === Categorical) + assert(split.threshold === Double.MinValue) + + val stats = bestSplits(0)._2 + assert(stats.gain > 0) + assert(stats.predict > 0.4) + assert(stats.predict < 0.5) + assert(stats.impurity > 0.2) + } + + test("stump with fixed label 0 for Gini") { + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Gini, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + Array[List[Filter]](), splits, bins) + assert(bestSplits.length === 1) + assert(bestSplits(0)._1.feature === 0) + assert(bestSplits(0)._1.threshold === 10) + assert(bestSplits(0)._2.gain === 0) + assert(bestSplits(0)._2.leftImpurity === 0) + assert(bestSplits(0)._2.rightImpurity === 0) + } + + test("stump with fixed label 1 for Gini") { + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Gini, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + Array[List[Filter]](), splits, bins) + assert(bestSplits.length === 1) + assert(bestSplits(0)._1.feature === 0) + assert(bestSplits(0)._1.threshold === 10) + assert(bestSplits(0)._2.gain === 0) + assert(bestSplits(0)._2.leftImpurity === 0) + assert(bestSplits(0)._2.rightImpurity === 0) + assert(bestSplits(0)._2.predict === 1) + } + + test("stump with fixed label 0 for Entropy") { + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Entropy, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + Array[List[Filter]](), splits, bins) + assert(bestSplits.length === 1) + assert(bestSplits(0)._1.feature === 0) + assert(bestSplits(0)._1.threshold === 10) + assert(bestSplits(0)._2.gain === 0) + assert(bestSplits(0)._2.leftImpurity === 0) + assert(bestSplits(0)._2.rightImpurity === 0) + assert(bestSplits(0)._2.predict === 0) + } + + test("stump with fixed label 1 for Entropy") { + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Entropy, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + Array[List[Filter]](), splits, bins) + assert(bestSplits.length === 1) + assert(bestSplits(0)._1.feature === 0) + assert(bestSplits(0)._1.threshold === 10) + assert(bestSplits(0)._2.gain === 0) + assert(bestSplits(0)._2.leftImpurity === 0) + assert(bestSplits(0)._2.rightImpurity === 0) + assert(bestSplits(0)._2.predict === 1) + } +} + +object DecisionTreeSuite { + + def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) + for (i <- 0 until 1000){ + val lp = new LabeledPoint(0.0,Array(i.toDouble,1000.0-i)) + arr(i) = lp + } + arr + } + + def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) + for (i <- 0 until 1000){ + val lp = new LabeledPoint(1.0,Array(i.toDouble,999.0-i)) + arr(i) = lp + } + arr + } + + def generateCategoricalDataPoints(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) + for (i <- 0 until 1000){ + if (i < 600){ + arr(i) = new LabeledPoint(1.0,Array(0.0,1.0)) + } else { + arr(i) = new LabeledPoint(0.0,Array(1.0,0.0)) + } + } + arr + } +} diff --git a/pom.xml b/pom.xml index 72acf2b402703..7d58060cba606 100644 --- a/pom.xml +++ b/pom.xml @@ -110,7 +110,7 @@ 1.6 - 2.10.3 + 2.10.4 2.10 0.13.0 org.spark-project.akka @@ -192,22 +192,22 @@ org.eclipse.jetty jetty-util - 7.6.8.v20121106 + 8.1.14.v20131031 org.eclipse.jetty jetty-security - 7.6.8.v20121106 + 8.1.14.v20131031 org.eclipse.jetty jetty-plus - 7.6.8.v20121106 + 8.1.14.v20131031 org.eclipse.jetty jetty-server - 7.6.8.v20121106 + 8.1.14.v20131031 com.google.guava @@ -380,7 +380,7 @@ lift-json_${scala.binary.version} 2.5.1 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 2549bc9710f1f..c5c697e8e2427 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -152,7 +152,7 @@ object SparkBuild extends Build { def sharedSettings = Defaults.defaultSettings ++ MimaBuild.mimaSettings(file(sparkHome)) ++ Seq( organization := "org.apache.spark", version := SPARK_VERSION, - scalaVersion := "2.10.3", + scalaVersion := "2.10.4", scalacOptions := Seq("-Xmax-classfile-name", "120", "-unchecked", "-deprecation", "-target:" + SCALAC_JVM_VERSION), javacOptions := Seq("-target", JAVAC_JVM_VERSION, "-source", JAVAC_JVM_VERSION), @@ -248,13 +248,13 @@ object SparkBuild extends Build { */ libraryDependencies ++= Seq( - "io.netty" % "netty-all" % "4.0.17.Final", - "org.eclipse.jetty" % "jetty-server" % "7.6.8.v20121106", - "org.eclipse.jetty" % "jetty-util" % "7.6.8.v20121106", - "org.eclipse.jetty" % "jetty-plus" % "7.6.8.v20121106", - "org.eclipse.jetty" % "jetty-security" % "7.6.8.v20121106", + "io.netty" % "netty-all" % "4.0.17.Final", + "org.eclipse.jetty" % "jetty-server" % "8.1.14.v20131031", + "org.eclipse.jetty" % "jetty-util" % "8.1.14.v20131031", + "org.eclipse.jetty" % "jetty-plus" % "8.1.14.v20131031", + "org.eclipse.jetty" % "jetty-security" % "8.1.14.v20131031", /** Workaround for SPARK-959. Dependency used by org.eclipse.jetty. Fixed in ivy 2.3.0. */ - "org.eclipse.jetty.orbit" % "javax.servlet" % "2.5.0.v201103041518" artifacts Artifact("javax.servlet", "jar", "jar"), + "org.eclipse.jetty.orbit" % "javax.servlet" % "3.0.0.v201112011016" artifacts Artifact("javax.servlet", "jar", "jar"), "org.scalatest" %% "scalatest" % "1.9.1" % "test", "org.scalacheck" %% "scalacheck" % "1.10.0" % "test", "com.novocode" % "junit-interface" % "0.10" % "test", diff --git a/project/plugins.sbt b/project/plugins.sbt index 5aa8a1ec2409b..d787237ddc540 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,4 +1,4 @@ -scalaVersion := "2.10.3" +scalaVersion := "2.10.4" resolvers += Resolver.url("artifactory", url("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases"))(Resolver.ivyStylePatterns) diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala index 5a307044ba123..0142256e90fb7 100644 --- a/project/project/SparkPluginBuild.scala +++ b/project/project/SparkPluginBuild.scala @@ -32,7 +32,7 @@ object SparkPluginDef extends Build { name := "spark-style", organization := "org.apache.spark", version := sparkVersion, - scalaVersion := "2.10.3", + scalaVersion := "2.10.4", scalacOptions := Seq("-unchecked", "-deprecation"), libraryDependencies ++= Dependencies.scalaStyle ) diff --git a/sql/README.md b/sql/README.md index 4192fecb92fb0..14d5555f0c713 100644 --- a/sql/README.md +++ b/sql/README.md @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.TestHive._ -Welcome to Scala version 2.10.3 (Java HotSpot(TM) 64-Bit Server VM, Java 1.7.0_45). +Welcome to Scala version 2.10.4 (Java HotSpot(TM) 64-Bit Server VM, Java 1.7.0_45). Type in expressions to have them evaluated. Type :help for more information. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index e09182dd8d5df..6b58b9322c4bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -31,6 +31,7 @@ trait Catalog { alias: Option[String] = None): LogicalPlan def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit + def unregisterTable(databaseName: Option[String], tableName: String): Unit } class SimpleCatalog extends Catalog { @@ -40,7 +41,7 @@ class SimpleCatalog extends Catalog { tables += ((tableName, plan)) } - def dropTable(tableName: String) = tables -= tableName + def unregisterTable(databaseName: Option[String], tableName: String) = { tables -= tableName } def lookupRelation( databaseName: Option[String], @@ -87,6 +88,10 @@ trait OverrideCatalog extends Catalog { plan: LogicalPlan): Unit = { overrides.put((databaseName, tableName), plan) } + + override def unregisterTable(databaseName: Option[String], tableName: String): Unit = { + overrides.remove((databaseName, tableName)) + } } /** @@ -104,4 +109,8 @@ object EmptyCatalog extends Catalog { def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = { throw new UnsupportedOperationException } + + def unregisterTable(databaseName: Option[String], tableName: String): Unit = { + throw new UnsupportedOperationException + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index 31d42b9ee71a0..6f939e6c41f6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -44,6 +44,16 @@ trait Row extends Seq[Any] with Serializable { s"[${this.mkString(",")}]" def copy(): Row + + /** Returns true if there are any NULL values in this row. */ + def anyNull: Boolean = { + var i = 0 + while (i < length) { + if (isNullAt(i)) { return true } + i += 1 + } + false + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 722ff517d250e..02fedd16b8d4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -21,6 +21,12 @@ import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.types.{BooleanType, StringType} +object InterpretedPredicate { + def apply(expression: Expression): (Row => Boolean) = { + (r: Row) => expression.apply(r).asInstanceOf[Boolean] + } +} + trait Predicate extends Expression { self: Product => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index cf3c06acce5b0..69bbbdc8943fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -26,8 +26,9 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.dsl import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.Optimizer -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Subquery, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.columnar.InMemoryColumnarTableScan import org.apache.spark.sql.execution._ /** @@ -111,13 +112,42 @@ class SQLContext(@transient val sparkContext: SparkContext) result } + /** Returns the specified table as a SchemaRDD */ + def table(tableName: String): SchemaRDD = + new SchemaRDD(this, catalog.lookupRelation(None, tableName)) + + /** Caches the specified table in-memory. */ + def cacheTable(tableName: String): Unit = { + val currentTable = catalog.lookupRelation(None, tableName) + val asInMemoryRelation = + InMemoryColumnarTableScan(currentTable.output, executePlan(currentTable).executedPlan) + + catalog.registerTable(None, tableName, SparkLogicalPlan(asInMemoryRelation)) + } + + /** Removes the specified table from the in-memory cache. */ + def uncacheTable(tableName: String): Unit = { + EliminateAnalysisOperators(catalog.lookupRelation(None, tableName)) match { + // This is kind of a hack to make sure that if this was just an RDD registered as a table, + // we reregister the RDD as a table. + case SparkLogicalPlan(inMem @ InMemoryColumnarTableScan(_, e: ExistingRdd)) => + inMem.cachedColumnBuffers.unpersist() + catalog.unregisterTable(None, tableName) + catalog.registerTable(None, tableName, SparkLogicalPlan(e)) + case SparkLogicalPlan(inMem: InMemoryColumnarTableScan) => + inMem.cachedColumnBuffers.unpersist() + catalog.unregisterTable(None, tableName) + case plan => throw new IllegalArgumentException(s"Table $tableName is not cached: $plan") + } + } + protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext = self.sparkContext val strategies: Seq[Strategy] = TopK :: PartialAggregation :: - SparkEquiInnerJoin :: + HashJoin :: ParquetOperations :: BasicOperators :: CartesianProduct :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 86f9d3e0fa954..e35ac0b6ca95a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.parquet._ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SQLContext#SparkPlanner => - object SparkEquiInnerJoin extends Strategy { + object HashJoin extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case FilteredOperation(predicates, logical.Join(left, right, Inner, condition)) => logger.debug(s"Considering join: ${predicates ++ condition}") @@ -51,8 +51,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val leftKeys = joinKeys.map(_._1) val rightKeys = joinKeys.map(_._2) - val joinOp = execution.SparkEquiInnerJoin( - leftKeys, rightKeys, planLater(left), planLater(right)) + val joinOp = execution.HashJoin( + leftKeys, rightKeys, BuildRight, planLater(left), planLater(right)) // Make sure other conditions are met if present. if (otherPredicates.nonEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index f0d21143ba5d1..c89dae9358bf7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -17,21 +17,22 @@ package org.apache.spark.sql.execution -import scala.collection.mutable +import scala.collection.mutable.{ArrayBuffer, BitSet} -import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext -import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning} -import org.apache.spark.rdd.PartitionLocalRDDFunctions._ +sealed abstract class BuildSide +case object BuildLeft extends BuildSide +case object BuildRight extends BuildSide -case class SparkEquiInnerJoin( +case class HashJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], + buildSide: BuildSide, left: SparkPlan, right: SparkPlan) extends BinaryNode { @@ -40,33 +41,93 @@ case class SparkEquiInnerJoin( override def requiredChildDistribution = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + val (buildPlan, streamedPlan) = buildSide match { + case BuildLeft => (left, right) + case BuildRight => (right, left) + } + + val (buildKeys, streamedKeys) = buildSide match { + case BuildLeft => (leftKeys, rightKeys) + case BuildRight => (rightKeys, leftKeys) + } + def output = left.output ++ right.output - def execute() = attachTree(this, "execute") { - val leftWithKeys = left.execute().mapPartitions { iter => - val generateLeftKeys = new Projection(leftKeys, left.output) - iter.map(row => (generateLeftKeys(row), row.copy())) - } + @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output) + @transient lazy val streamSideKeyGenerator = + () => new MutableProjection(streamedKeys, streamedPlan.output) - val rightWithKeys = right.execute().mapPartitions { iter => - val generateRightKeys = new Projection(rightKeys, right.output) - iter.map(row => (generateRightKeys(row), row.copy())) - } + def execute() = { - // Do the join. - val joined = filterNulls(leftWithKeys).joinLocally(filterNulls(rightWithKeys)) - // Drop join keys and merge input tuples. - joined.map { case (_, (leftTuple, rightTuple)) => buildRow(leftTuple ++ rightTuple) } - } + buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => + // TODO: Use Spark's HashMap implementation. + val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]() + var currentRow: Row = null + + // Create a mapping of buildKeys -> rows + while (buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = buildSideKeyGenerator(currentRow) + if(!rowKey.anyNull) { + val existingMatchList = hashTable.get(rowKey) + val matchList = if (existingMatchList == null) { + val newMatchList = new ArrayBuffer[Row]() + hashTable.put(rowKey, newMatchList) + newMatchList + } else { + existingMatchList + } + matchList += currentRow.copy() + } + } + + new Iterator[Row] { + private[this] var currentStreamedRow: Row = _ + private[this] var currentHashMatches: ArrayBuffer[Row] = _ + private[this] var currentMatchPosition: Int = -1 - /** - * Filters any rows where the any of the join keys is null, ensuring three-valued - * logic for the equi-join conditions. - */ - protected def filterNulls(rdd: RDD[(Row, Row)]) = - rdd.filter { - case (key: Seq[_], _) => !key.exists(_ == null) + // Mutable per row objects. + private[this] val joinRow = new JoinedRow + + private[this] val joinKeys = streamSideKeyGenerator() + + override final def hasNext: Boolean = + (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || + (streamIter.hasNext && fetchNext()) + + override final def next() = { + val ret = joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) + currentMatchPosition += 1 + ret + } + + /** + * Searches the streamed iterator for the next row that has at least one match in hashtable. + * + * @return true if the search is successful, and false the streamed iterator runs out of + * tuples. + */ + private final def fetchNext(): Boolean = { + currentHashMatches = null + currentMatchPosition = -1 + + while (currentHashMatches == null && streamIter.hasNext) { + currentStreamedRow = streamIter.next() + if (!joinKeys(currentStreamedRow).anyNull) { + currentHashMatches = hashTable.get(joinKeys.currentValue) + } + } + + if (currentHashMatches == null) { + false + } else { + currentMatchPosition = 0 + true + } + } + } } + } } case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { @@ -95,17 +156,19 @@ case class BroadcastNestedLoopJoin( def right = broadcast @transient lazy val boundCondition = - condition - .map(c => BindReferences.bindReference(c, left.output ++ right.output)) - .getOrElse(Literal(true)) + InterpretedPredicate( + condition + .map(c => BindReferences.bindReference(c, left.output ++ right.output)) + .getOrElse(Literal(true))) def execute() = { val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter => - val matchedRows = new mutable.ArrayBuffer[Row] - val includedBroadcastTuples = new mutable.BitSet(broadcastedRelation.value.size) + val matchedRows = new ArrayBuffer[Row] + // TODO: Use Spark's BitSet. + val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size) val joinedRow = new JoinedRow streamedIter.foreach { streamedRow => @@ -115,7 +178,7 @@ case class BroadcastNestedLoopJoin( while (i < broadcastedRelation.value.size) { // TODO: One bitset per partition instead of per row. val broadcastedRow = broadcastedRelation.value(i) - if (boundCondition(joinedRow(streamedRow, broadcastedRow)).asInstanceOf[Boolean]) { + if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { matchedRows += buildRow(streamedRow ++ broadcastedRow) matched = true includedBroadcastTuples += i diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala new file mode 100644 index 0000000000000..e5902c3cae381 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalatest.FunSuite +import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.execution.SparkLogicalPlan +import org.apache.spark.sql.columnar.InMemoryColumnarTableScan + +class CachedTableSuite extends QueryTest { + TestData // Load test tables. + + test("read from cached table and uncache") { + TestSQLContext.cacheTable("testData") + + checkAnswer( + TestSQLContext.table("testData"), + testData.collect().toSeq + ) + + TestSQLContext.table("testData").queryExecution.analyzed match { + case SparkLogicalPlan(_ : InMemoryColumnarTableScan) => // Found evidence of caching + case noCache => fail(s"No cache node found in plan $noCache") + } + + TestSQLContext.uncacheTable("testData") + + checkAnswer( + TestSQLContext.table("testData"), + testData.collect().toSeq + ) + + TestSQLContext.table("testData").queryExecution.analyzed match { + case cachePlan @ SparkLogicalPlan(_ : InMemoryColumnarTableScan) => + fail(s"Table still cached after uncache: $cachePlan") + case noCache => // Table uncached successfully + } + } + + test("correct error on uncache of non-cached table") { + intercept[IllegalArgumentException] { + TestSQLContext.uncacheTable("testData") + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index fc5057b73fe24..197b557cba5f4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -194,7 +194,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { DataSinks, Scripts, PartialAggregation, - SparkEquiInnerJoin, + HashJoin, BasicOperators, CartesianProduct, BroadcastNestedLoopJoin diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 4f8353666a12b..29834a11f41dc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -141,6 +141,13 @@ class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Logging { */ override def registerTable( databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = ??? + + /** + * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. + * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. + */ + override def unregisterTable( + databaseName: Option[String], tableName: String): Unit = ??? } object HiveMetastoreTypes extends RegexParsers { diff --git a/sql/hive/src/test/resources/golden/read from cached table-0-ce3797dc14a603cba2a5e58c8612de5b b/sql/hive/src/test/resources/golden/read from cached table-0-ce3797dc14a603cba2a5e58c8612de5b new file mode 100644 index 0000000000000..60878ffb77064 --- /dev/null +++ b/sql/hive/src/test/resources/golden/read from cached table-0-ce3797dc14a603cba2a5e58c8612de5b @@ -0,0 +1 @@ +238 val_238 diff --git a/sql/hive/src/test/resources/golden/read from uncached table-0-ce3797dc14a603cba2a5e58c8612de5b b/sql/hive/src/test/resources/golden/read from uncached table-0-ce3797dc14a603cba2a5e58c8612de5b new file mode 100644 index 0000000000000..60878ffb77064 --- /dev/null +++ b/sql/hive/src/test/resources/golden/read from uncached table-0-ce3797dc14a603cba2a5e58c8612de5b @@ -0,0 +1 @@ +238 val_238 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala new file mode 100644 index 0000000000000..68d45e53cdf26 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.execution.SparkLogicalPlan +import org.apache.spark.sql.columnar.InMemoryColumnarTableScan +import org.apache.spark.sql.hive.execution.HiveComparisonTest + +class CachedTableSuite extends HiveComparisonTest { + TestHive.loadTestTable("src") + + test("cache table") { + TestHive.cacheTable("src") + } + + createQueryTest("read from cached table", + "SELECT * FROM src LIMIT 1") + + test("check that table is cached and uncache") { + TestHive.table("src").queryExecution.analyzed match { + case SparkLogicalPlan(_ : InMemoryColumnarTableScan) => // Found evidence of caching + case noCache => fail(s"No cache node found in plan $noCache") + } + TestHive.uncacheTable("src") + } + + createQueryTest("read from uncached table", + "SELECT * FROM src LIMIT 1") + + test("make sure table is uncached") { + TestHive.table("src").queryExecution.analyzed match { + case cachePlan @ SparkLogicalPlan(_ : InMemoryColumnarTableScan) => + fail(s"Table still cached after uncache: $cachePlan") + case noCache => // Table uncached successfully + } + } + + test("correct error on uncache of non-cached table") { + intercept[IllegalArgumentException] { + TestHive.uncacheTable("src") + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala index 7d18a0fcf7ba8..9ebf7b484f421 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala @@ -36,8 +36,9 @@ class RateLimitedOutputStreamSuite extends FunSuite { val stream = new RateLimitedOutputStream(underlying, desiredBytesPerSec = 10000) val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) } - // We accept anywhere from 4.0 to 4.99999 seconds since the value is rounded down. - assert(SECONDS.convert(elapsedNs, NANOSECONDS) === 4) + val seconds = SECONDS.convert(elapsedNs, NANOSECONDS) + assert(seconds >= 4, s"Seconds value ($seconds) is less than 4.") + assert(seconds <= 30, s"Took more than 30 seconds ($seconds) to write data.") assert(underlying.toString("UTF-8") === data) } } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index c565f2dde24fc..3e4c739e34fe9 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -63,7 +63,10 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { userClass = value args = tail - case ("--args") :: value :: tail => + case ("--args" | "--arg") :: value :: tail => + if (args(0) == "--args") { + println("--args is deprecated. Use --arg instead.") + } userArgsBuffer += value args = tail @@ -146,8 +149,8 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { "Options:\n" + " --jar JAR_PATH Path to your application's JAR file (required in yarn-cluster mode)\n" + " --class CLASS_NAME Name of your application's main class (required)\n" + - " --args ARGS Arguments to be passed to your application's main class.\n" + - " Mutliple invocations are possible, each will be passed in order.\n" + + " --arg ARGS Argument to be passed to your application's main class.\n" + + " Multiple invocations are possible, each will be passed in order.\n" + " --num-executors NUM Number of executors to start (Default: 2)\n" + " --executor-cores NUM Number of cores for the executors (Default: 1).\n" + " --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: 512 Mb)\n" +