diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 3d5eb0fcf263b..db91d90a7eaa4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -19,18 +19,16 @@ package org.apache.spark.mllib.tree import scala.util.control.Breaks._ +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.tree.model._ -import org.apache.spark.{SparkContext, Logging} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.Split import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity} -import java.util.Random +import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} +import org.apache.spark.mllib.tree.model._ +import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom /** @@ -38,49 +36,50 @@ import org.apache.spark.util.random.XORShiftRandom * supports both continuous and categorical features. * @param strategy The configuration parameters for the tree algorithm which specify the type * of algorithm (classification, regression, etc.), feature type (continuous, - * categorical), - * depth of the tree, quantile calculation strategy, etc. - */ + * categorical), depth of the tree, quantile calculation strategy, etc. + */ class DecisionTree private(val strategy: Strategy) extends Serializable with Logging { /** * Method to train a decision tree model over an RDD * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data - * for DecisionTree * @return a DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint]): DecisionTreeModel = { - // Cache input RDD for speedup during multiple passes + // Cache input RDD for speedup during multiple passes. input.cache() logDebug("algo = " + strategy.algo) - // Finding the splits and the corresponding bins (interval between the splits) using a sample + // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) logDebug("numSplits = " + bins(0).length) - // Noting numBins for the input data + // Set number of bins for the input data. strategy.numBins = bins(0).length - // The depth of the decision tree + // depth of the decision tree val maxDepth = strategy.maxDepth - // The max number of nodes possible given the depth of the tree + // the max number of nodes possible given the depth of the tree val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1 - // Initalizing an array to hold filters applied to points for each node + // Initialize an array to hold filters applied to points for each node. val filters = new Array[List[Filter]](maxNumNodes) - // The filter at the top node is an empty list + // The filter at the top node is an empty list. 
filters(0) = List() - // Initializing an array to hold parent impurity calculations for each node + // Initialize an array to hold parent impurity calculations for each node. val parentImpurities = new Array[Double](maxNumNodes) - // Dummy value for top node (updated during first split calculation) + // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) - // The main-idea here is to perform level-wise training of the decision tree nodes thus - // reducing the passes over the data from l to log2(l) where l is the total number of nodes. - // Each data sample is checked for validity w.r.t to each node at a given level -- i.e., - // the sample is only used for the split calculation at the node if the sampled would have - // still survived the filters of the parent nodes. + + /* + * The main idea here is to perform level-wise training of the decision tree nodes thus + * reducing the passes over the data from l to log2(l) where l is the total number of nodes. + * Each data sample is checked for validity w.r.t. each node at a given level -- i.e., + * the sample is only used for the split calculation at the node if the sample would have + * still survived the filters of the parent nodes. + */ // TODO: Convert for loop to while loop breakable { @@ -90,35 +89,32 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log logDebug("level = " + level) logDebug("#####################################") - // Find best split for all nodes at a level + // Find best split for all nodes at a level. val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, level, filters, splits, bins) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { - // Extract info for nodes at the current level + // Extract info for nodes at the current level. extractNodeInfo(nodeSplitStats, level, index, nodes) - // Extract info for nodes at the next lower level + // Extract info for nodes at the next lower level. extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, filters) logDebug("final best split = " + nodeSplitStats._1) - } require(scala.math.pow(2, level) == splitsStatsForLevel.length) - // Check whether all the nodes at the current level at leaves + // Check whether all the nodes at the current level are leaves. val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) logDebug("all leaf = " + allLeaf) - if (allLeaf) break //no more tree construction - + if (allLeaf) break // no more tree construction } } - // Initialize the top or root node of the tree + // Initialize the top or root node of the tree. val topNode = nodes(0) - // Build the full tree using the node info calculated in the level-wise best split calculations + // Build the full tree using the node info calculated in the level-wise best split calculations. 
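For readers of the patch: `filters`, `parentImpurities`, and `nodes` above all share one implicit complete-binary-tree layout. A minimal, self-contained sketch of that indexing convention (the object and helper names are illustrative, not part of the change):

```scala
object TreeIndexingSketch extends App {
  val maxDepth = 3
  // Same formula as maxNumNodes in train(): a full binary tree of depth 3 has 7 nodes.
  val maxNumNodes = math.pow(2, maxDepth).toInt - 1
  // A node at position `index` within `level` lives at 2^level - 1 + index.
  def nodeIndex(level: Int, index: Int): Int = math.pow(2, level).toInt - 1 + index
  // Matches extractInfoForLowerLevels: 2^(level + 1) - 1 + 2 * index + i, for i in {0, 1}.
  def childIndex(level: Int, index: Int, i: Int): Int = nodeIndex(level + 1, 2 * index + i)
  assert(nodeIndex(1, 1) == 2)      // second node of level 1
  assert(childIndex(1, 1, 0) == 5)  // its left child
  assert(childIndex(1, 1, 1) == 6)  // its right child
  println(s"maxNumNodes = $maxNumNodes")
}
```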
topNode.build(nodes) - // Return a decision tree model - return new DecisionTreeModel(topNode, strategy.algo) + new DecisionTreeModel(topNode, strategy.algo) } /** @@ -128,9 +124,7 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log nodeSplitStats: (Split, InformationGainStats), level: Int, index: Int, - nodes: Array[Node]) - : Unit = { - + nodes: Array[Node]): Unit = { val split = nodeSplitStats._1 val stats = nodeSplitStats._2 val nodeIndex = scala.math.pow(2, level).toInt - 1 + index @@ -149,13 +143,11 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log maxDepth: Int, nodeSplitStats: (Split, InformationGainStats), parentImpurities: Array[Double], - filters: Array[List[Filter]]) - : Unit = { - + filters: Array[List[Filter]]): Unit = { // 0 corresponds to the left child node and 1 corresponds to the right child node. // TODO: Convert to while loop for (i <- 0 to 1) { - // Calculating the index of the node from the node level and the index at the current level + // Calculate the index of the node from the node level and the index at the current level. val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i if (level < maxDepth - 1) { val impurity = if (i == 0) { @@ -184,7 +176,7 @@ object DecisionTree extends Serializable with Logging { * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data * for DecisionTree * @param strategy The configuration parameters for the tree algorithm which specify the type - * of algoritm (classification, regression, etc.), feature type (continuous, + * of algorithm (classification, regression, etc.), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. * @return a DecisionTreeModel that can be used for prediction */ @@ -196,7 +188,7 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model over an RDD * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as * training data - * @param algo algo classification or regression + * @param algo algorithm, classification or regression * @param impurity impurity criterion used for information gain calculation * @param maxDepth maxDepth maximum depth of the tree * @return a DecisionTreeModel that can be used for prediction @@ -205,8 +197,7 @@ object DecisionTree extends Serializable with Logging { input: RDD[LabeledPoint], algo: Algo, impurity: Impurity, - maxDepth: Int) - : DecisionTreeModel = { + maxDepth: Int): DecisionTreeModel = { val strategy = new Strategy(algo,impurity,maxDepth) new DecisionTree(strategy).train(input: RDD[LabeledPoint]) } @@ -235,8 +226,7 @@ object DecisionTree extends Serializable with Logging { maxDepth: Int, maxBins: Int, quantileCalculationStrategy: QuantileStrategy, - categoricalFeaturesInfo: Map[Int,Int]) - : DecisionTreeModel = { + categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo) new DecisionTree(strategy).train(input: RDD[LabeledPoint]) @@ -264,44 +254,42 @@ object DecisionTree extends Serializable with Logging { level: Int, filters: Array[List[Filter]], splits: Array[Array[Split]], - bins: Array[Array[Bin]]) - : Array[(Split, InformationGainStats)] = { - - - // The high-level description for the best split optimizations are noted here. 
- // - // *Level-wise training* - // We perform bin calculations for all nodes at the given level to avoid making multiple - // passes over the data. Thus, for a slightly increased computation and storage cost we save - // several iterations over the data especially at higher levels of the decision tree. - // - // *Bin-wise computation* - // We use a bin-wise best split computation strategy instead of a straightforward best split - // computation strategy. Instead of analyzing each sample for contribution to the left/right - // child node impurity of every split, we first categorize each feature of a sample into a - // bin. Each bin is an interval between a low and high split. Since each splits, and thus bin, - // is ordered (read ordering for categorical variables in the findSplitsBins method), - // we exploit this structure to calculate aggregates for bins and then use these aggregates - // to calculate information gain for each split. - // - // *Aggregation over partitions* - // Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know - // the number of splits in advance. Thus, we store the aggregates (at the appropriate - // indices) in a single array for all bins and rely upon the RDD aggregate method to - // drastically reduce the communication overhead. - - // Implementation below - - // Common calculations for multiple nested methods + bins: Array[Array[Bin]]): Array[(Split, InformationGainStats)] = { + + /* + * The high-level description for the best split optimizations is noted here. + * + * *Level-wise training* + * We perform bin calculations for all nodes at the given level to avoid making multiple + * passes over the data. Thus, for a slightly increased computation and storage cost we save + * several iterations over the data especially at higher levels of the decision tree. + * + * *Bin-wise computation* + * We use a bin-wise best split computation strategy instead of a straightforward best split + * computation strategy. Instead of analyzing each sample for contribution to the left/right + * child node impurity of every split, we first categorize each feature of a sample into a + * bin. Each bin is an interval between a low and high split. Since each split, and thus bin, + * is ordered (read ordering for categorical variables in the findSplitsBins method), + * we exploit this structure to calculate aggregates for bins and then use these aggregates + * to calculate information gain for each split. + * + * *Aggregation over partitions* + * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know + * the number of splits in advance. Thus, we store the aggregates (at the appropriate + * indices) in a single array for all bins and rely upon the RDD aggregate method to + * drastically reduce the communication overhead. + */ + + // common calculations for multiple nested methods val numNodes = scala.math.pow(2, level).toInt logDebug("numNodes = " + numNodes) - // Find the number of features by looking at the first sample + // Find the number of features by looking at the first sample. val numFeatures = input.first().features.length logDebug("numFeatures = " + numFeatures) val numBins = strategy.numBins logDebug("numBins = " + numBins) - /** Find the filters used before reaching the current code */ + /** Find the filters used before reaching the current code. 
*/ def findParentFilters(nodeIndex: Int): List[Filter] = { if (level == 0) { List[Filter]() @@ -312,13 +300,12 @@ object DecisionTree extends Serializable with Logging { } /** - * Find whether the sample is valid input for the current node. In other words, - * does it pass through all the filters for the current node. - */ + * Find whether the sample is valid input for the current node, i.e., whether it passes through + * all the filters for the current node. + */ def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = { - - // Leaf - if ((level > 0) & (parentFilters.length == 0) ){ + // leaf + if ((level > 0) & (parentFilters.length == 0)) { return false } @@ -331,39 +318,37 @@ object DecisionTree extends Serializable with Logging { val categories = filter.split.categories val isFeatureContinuous = filter.split.featureType == Continuous val feature = features(featureIndex) - if (isFeatureContinuous){ + if (isFeatureContinuous) { comparison match { - case(-1) => if (feature > threshold) return false - case(1) => if (feature <= threshold) return false + case -1 => if (feature > threshold) return false + case 1 => if (feature <= threshold) return false } } else { val containsFeature = categories.contains(feature) comparison match { - case(-1) => if (!containsFeature) return false - case(1) => if (containsFeature) return false + case -1 => if (!containsFeature) return false + case 1 => if (containsFeature) return false } } } - //Return true when the sample is valid for all filters + // Return true when the sample is valid for all filters. true } /** - * Find bin for one feature + * Find bin for one feature. */ def findBin( featureIndex: Int, labeledPoint: LabeledPoint, - isFeatureContinuous: Boolean) - : Int = { - + isFeatureContinuous: Boolean): Int = { val binForFeatures = bins(featureIndex) val feature = labeledPoint.features(featureIndex) /** - * Binary search helper method for continuous feature + * Binary search helper method for continuous feature. */ def binarySearchForBins(): Int = { var left = 0 @@ -376,7 +361,7 @@ object DecisionTree extends Serializable with Logging { if ((lowThreshold < feature) & (highThreshold >= feature)){ return mid } - else if ((lowThreshold >= feature)){ + else if (lowThreshold >= feature) { right = mid - 1 } else { @@ -387,9 +372,9 @@ object DecisionTree extends Serializable with Logging { } /** - * Sequential search helper method to find bin for categorical feature + * Sequential search helper method to find bin for categorical feature. */ - def sequentialBinSearchForCategoricalFeature() : Int = { + def sequentialBinSearchForCategoricalFeature(): Int = { val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex) var binIndex = 0 while (binIndex < numCategoricalBins) { @@ -404,7 +389,7 @@ object DecisionTree extends Serializable with Logging { -1 } - if (isFeatureContinuous){ + if (isFeatureContinuous) { // Perform binary search for finding bin for continuous features. val binIndex = binarySearchForBins() if (binIndex == -1){ @@ -424,27 +409,26 @@ object DecisionTree extends Serializable with Logging { /** * Finds bins for all nodes (and all features) at a given level. * For l nodes, k features the storage is as follows: - * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk + * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk, * where b_ij is an integer between 0 and numBins - 1. 
- * Invalid sample is denoted by noting bin for feature 1 as -1 + * Invalid sample is denoted by noting bin for feature 1 as -1. */ def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { - - // calculating bin index and label per feature per node + // Calculate bin index and label per feature per node. val arr = new Array[Double](1 + (numFeatures * numNodes)) arr(0) = labeledPoint.label var nodeIndex = 0 while (nodeIndex < numNodes) { val parentFilters = findParentFilters(nodeIndex) - // Find out whether the sample qualifies for the particular node + // Find out whether the sample qualifies for the particular node. val sampleValid = isSampleValid(parentFilters, labeledPoint) val shift = 1 + numFeatures * nodeIndex if (!sampleValid) { - // marking one bin as -1 is sufficient + // Marking one bin as -1 is sufficient. arr(shift) = InvalidBinIndex } else { var featureIndex = 0 - while (featureIndex < numFeatures){ + while (featureIndex < numFeatures) { val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinuous) featureIndex += 1 @@ -461,33 +445,33 @@ object DecisionTree extends Serializable with Logging { * incremented based upon whether the feature is classified as 0 or 1. * * @param agg Array[Double] storing aggregate calculation of size - * 2*numSplits*numFeatures*numNodes for classification - * @param arr Array[Double] of size 1+(numFeatures*numNodes) - * @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes - * for classification + * 2 * numSplits * numFeatures * numNodes for classification + * @param arr Array[Double] of size 1 + (numFeatures * numNodes) + * @return Array[Double] storing aggregate calculation of size + * 2 * numSplits * numFeatures * numNodes for classification */ def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { - // Iterating over all nodes + // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { - // Checking whether the instance was valid for this nodeIndex + // Check whether the instance was valid for this nodeIndex. val validSignalIndex = 1 + numFeatures * nodeIndex val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex if (isSampleValidForNode) { - // Actual class label + // actual class label val label = arr(0) - // Iterating over all features + // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures) { - // Finding the bin index for this feature + // Find the bin index for this feature. val arrShift = 1 + numFeatures * nodeIndex val arrIndex = arrShift + featureIndex - // Updating the left or right count for one bin + // Update the left or right count for one bin. val aggShift = 2 * numBins * numFeatures * nodeIndex val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2 label match { - case (0.0) => agg(aggIndex) = agg(aggIndex) + 1 - case (1.0) => agg(aggIndex + 1) = agg(aggIndex + 1) + 1 + case 0.0 => agg(aggIndex) = agg(aggIndex) + 1 + case 1.0 => agg(aggIndex + 1) = agg(aggIndex + 1) + 1 } featureIndex += 1 } @@ -501,34 +485,33 @@ object DecisionTree extends Serializable with Logging { * the count, sum, sum of squares of one of the p bins is incremented. 
* * @param agg Array[Double] storing aggregate calculation of size - * 3*numSplits*numFeatures*numNodes for classification - * @param arr Array[Double] of size 1+(numFeatures*numNodes) + * 3 * numSplits * numFeatures * numNodes for regression + * @param arr Array[Double] of size 1 + (numFeatures * numNodes) * @return Array[Double] storing aggregate calculation of size - * 3*numSplits*numFeatures*numNodes for regression + * 3 * numSplits * numFeatures * numNodes for regression */ def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) { - // Iterating over all nodes + // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { - // Checking whether the instance was valid for this nodeIndex + // Check whether the instance was valid for this nodeIndex. val validSignalIndex = 1 + numFeatures * nodeIndex val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex if (isSampleValidForNode) { - // Actual class label + // actual class label val label = arr(0) - // Iterating over all features + // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures) { - // Finding the bin index for this feature + // Find the bin index for this feature. val arrShift = 1 + numFeatures * nodeIndex val arrIndex = arrShift + featureIndex - // updating count, sum, sum^2 for one bin + // Update count, sum, and sum^2 for one bin. val aggShift = 3 * numBins * numFeatures * nodeIndex val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3 agg(aggIndex) = agg(aggIndex) + 1 agg(aggIndex + 1) = agg(aggIndex + 1) + label agg(aggIndex + 2) = agg(aggIndex + 2) + label*label - // increment featureIndex featureIndex += 1 } } @@ -547,15 +530,15 @@ object DecisionTree extends Serializable with Logging { agg } - // Calculating bin aggregate length for classification or regression + // Calculate bin aggregate length for classification or regression. val binAggregateLength = strategy.algo match { - case Classification => 2*numBins * numFeatures * numNodes - case Regression => 3*numBins * numFeatures * numNodes + case Classification => 2 * numBins * numFeatures * numNodes + case Regression => 3 * numBins * numFeatures * numNodes } logDebug("binAggregateLength = " + binAggregateLength) /** - * Combines the aggregates from partitions + * Combines the aggregates from partitions. * @param agg1 Array containing aggregates from one or more partitions * @param agg2 Array containing aggregates from one or more partitions * @return Combined aggregate from agg1 and agg2 */ def binCombOp(agg1: Array[Double], agg2: Array[Double]): Array[Double] = { var index = 0 val combinedAggregate = new Array[Double](binAggregateLength) - while (index < binAggregateLength){ + while (index < binAggregateLength) { combinedAggregate(index) = agg1(index) + agg2(index) index += 1 } combinedAggregate } - // find feature bins for all nodes at a level + // Find feature bins for all nodes at a level. val binMappedRDD = input.map(x => findBinsForLevel(x)) - // calculate bin aggregates + // Calculate bin aggregates. val binAggregates = { binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) } logDebug("binAggregates.length = " + binAggregates.length) /** - * Calculates the information gain for all splits based upon left/right split aggregates + * Calculates the information gain for all splits based upon left/right split aggregates. 
* @param leftNodeAgg left node aggregates * @param featureIndex feature index * @param splitIndex split index @@ -593,12 +576,9 @@ object DecisionTree extends Serializable with Logging { featureIndex: Int, splitIndex: Int, rightNodeAgg: Array[Array[Double]], - topImpurity: Double) - : InformationGainStats = { - + topImpurity: Double): InformationGainStats = { strategy.algo match { - case Classification => { - + case Classification => val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex) val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1) val leftCount = left0Count + left1Count @@ -611,7 +591,7 @@ object DecisionTree extends Serializable with Logging { if (level > 0) { topImpurity } else { - // Calculating impurity for root node + // Calculate impurity for root node. strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count) } } @@ -640,8 +620,7 @@ object DecisionTree extends Serializable with Logging { val predict = (left1Count + right1Count) / (leftCount + rightCount) new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) - } - case Regression => { + case Regression => val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex) val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1) val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2) @@ -654,7 +633,7 @@ object DecisionTree extends Serializable with Logging { if (level > 0) { topImpurity } else { - // Calculating impurity for root node + // Calculate impurity for root node. val count = leftCount + rightCount val sum = leftSum + rightSum val sumSquares = leftSumSquares + rightSumSquares @@ -687,31 +666,27 @@ object DecisionTree extends Serializable with Logging { val predict = (leftSum + rightSum) / (leftCount + rightCount) new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) - - } } } /** - * Extracts left and right split aggregates + * Extracts left and right split aggregates. * @param binData Array[Double] of size 2*numFeatures*numSplits * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double], * Array[Double]) where each array is of size(numFeature,2*(numSplits-1)) */ def extractLeftRightNodeAggregates( - binData: Array[Double]) - : (Array[Array[Double]], Array[Array[Double]]) = { - + binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { strategy.algo match { - case Classification => { - // Initializing left and right split aggregates + case Classification => + // Initialize left and right split aggregates. val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) - // Iterating over all features + // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures) { // shift for this featureIndex - val shift = 2*featureIndex*numBins + val shift = 2 * featureIndex * numBins // left node aggregate for the lowest split leftNodeAgg(featureIndex)(0) = binData(shift + 0) @@ -723,7 +698,7 @@ object DecisionTree extends Serializable with Logging { rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1) = binData(shift + (2 * (numBins - 1)) + 1) - // Iterating over all splits + // Iterate over all splits. 
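A small, self-contained illustration of what the flat-array bookkeeping in `extractLeftRightNodeAggregates` computes, using invented counts and nested arrays in place of the shift arithmetic: left/right split aggregates are prefix/suffix sums over per-bin label counts, and a split's gain then follows in the spirit of `calculateGainForSplit` (Gini impurity assumed here):

```scala
// Per-bin (label-0 count, label-1 count) for one feature: 3 bins => 2 candidate splits.
val binCounts = Array(Array(3.0, 1.0), Array(2.0, 2.0), Array(0.0, 4.0))

// leftAgg(s) sums bins 0..s; rightAgg(s) sums bins s+1..last.
val leftAgg = binCounts.init.scanLeft(Array(0.0, 0.0)) { (acc, bin) =>
  Array(acc(0) + bin(0), acc(1) + bin(1))
}.tail
val rightAgg = binCounts.tail.scanRight(Array(0.0, 0.0)) { (bin, acc) =>
  Array(bin(0) + acc(0), bin(1) + acc(1))
}.init

def gini(c0: Double, c1: Double): Double = {
  val total = c0 + c1
  if (total == 0) 0.0 else 1.0 - math.pow(c0 / total, 2) - math.pow(c1 / total, 2)
}

// Gain for the lowest split: parent impurity minus weighted child impurities.
val (l0, l1) = (leftAgg(0)(0), leftAgg(0)(1))    // (3.0, 1.0)
val (r0, r1) = (rightAgg(0)(0), rightAgg(0)(1))  // (2.0, 6.0)
val total = l0 + l1 + r0 + r1
val gain = gini(l0 + r0, l1 + r1) -
  (l0 + l1) / total * gini(l0, l1) -
  (r0 + r1) / total * gini(r0, r1)
```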
var splitIndex = 1 while (splitIndex < numBins - 1) { // calculating left node aggregate for a split as a sum of left node aggregate of a // lower split and the left bin aggregate of a bin where the split is a high split - leftNodeAgg(featureIndex)(3 * splitIndex) - = binData(shift + 3 * splitIndex) + + leftNodeAgg(featureIndex)(3 * splitIndex) = binData(shift + 3 * splitIndex) + leftNodeAgg(featureIndex)(3 * splitIndex - 3) - leftNodeAgg(featureIndex)(3 * splitIndex + 1) - = binData(shift + 3 * splitIndex + 1) + + leftNodeAgg(featureIndex)(3 * splitIndex + 1) = binData(shift + 3 * splitIndex + 1) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1) - leftNodeAgg(featureIndex)(3 * splitIndex + 2) - = binData(shift + 3 * splitIndex + 2) + + leftNodeAgg(featureIndex)(3 * splitIndex + 2) = binData(shift + 3 * splitIndex + 2) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2) // calculating right node aggregate for a split as a sum of right node aggregate of a // higher split and the right bin aggregate of a bin where the split is a low split - rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) - = binData(shift + (3 * (numBins - 2 - splitIndex))) + - rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex)) - rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) - = binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) + - rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1) - rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) - = binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) + - rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2) + rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) = + binData(shift + (3 * (numBins - 2 - splitIndex))) + + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex)) + rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) = + binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) + + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1) + rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) = + binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) + + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2) splitIndex += 1 } featureIndex += 1 } (leftNodeAgg, rightNodeAgg) - } } } /** - * Calculates information gain for all nodes splits + * Calculates information gain for all node splits. 
*/ def calculateGainsForAllNodeSplits( leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], - nodeImpurity: Double) - : Array[Array[InformationGainStats]] = { - + nodeImpurity: Double): Array[Array[InformationGainStats]] = { val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) for (featureIndex <- 0 until numFeatures) { - for (splitIndex <- 0 until numBins -1) { + for (splitIndex <- 0 until numBins - 1) { gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex, splitIndex, rightNodeAgg, nodeImpurity) } @@ -827,38 +794,37 @@ object DecisionTree extends Serializable with Logging { gains } - /** - * Find the best split for a node - * @param binData Array[Double] of size 2*numSplits*numFeatures + /** + * Find the best split for a node. + * @param binData Array[Double] of size 2 * numSplits * numFeatures * @param nodeImpurity impurity of the top node * @return tuple of split and information gain */ def binsToBestSplit( binData: Array[Double], - nodeImpurity: Double) - : (Split, InformationGainStats) = { + nodeImpurity: Double): (Split, InformationGainStats) = { logDebug("node impurity = " + nodeImpurity) - //extract left right node aggregates + // Extract left right node aggregates. val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) - // calculate gains for all splits + // Calculate gains for all splits. val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) val (bestFeatureIndex,bestSplitIndex, gainStats) = { - // Initialization with infeasible values + // Initialize with infeasible values. var bestFeatureIndex = Int.MinValue var bestSplitIndex = Int.MinValue - var bestGainStats = new InformationGainStats(Double.MinValue,-1.0,-1.0,-1.0,-1) - // Iterating over features + var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0) + // Iterate over features. var featureIndex = 0 while (featureIndex < numFeatures) { - // Iterating over all splits + // Iterate over all splits. var splitIndex = 0 - while (splitIndex < numBins - 1){ - val gainStats = gains(featureIndex)(splitIndex) - if(gainStats.gain > bestGainStats.gain) { + while (splitIndex < numBins - 1) { + val gainStats = gains(featureIndex)(splitIndex) + if (gainStats.gain > bestGainStats.gain) { bestGainStats = gainStats bestFeatureIndex = featureIndex bestSplitIndex = splitIndex @@ -867,29 +833,28 @@ object DecisionTree extends Serializable with Logging { } featureIndex += 1 } - (bestFeatureIndex,bestSplitIndex,bestGainStats) + (bestFeatureIndex, bestSplitIndex, bestGainStats) } logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex)) logDebug("best split bin = " + splits(bestFeatureIndex)(bestSplitIndex)) - (splits(bestFeatureIndex)(bestSplitIndex),gainStats) + + (splits(bestFeatureIndex)(bestSplitIndex), gainStats) } /** - * get bin data for one node + * Get bin data for one node. 
*/ def getBinDataForNode(node: Int): Array[Double] = { strategy.algo match { - case Classification => { + case Classification => val shift = 2 * node * numBins * numFeatures val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures) binsForNode - } - case Regression => { + case Regression => val shift = 3 * node * numBins * numFeatures val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures) binsForNode - } } } @@ -897,7 +862,7 @@ object DecisionTree extends Serializable with Logging { val bestSplits = new Array[(Split, InformationGainStats)](numNodes) // Iterating over all nodes at this level var node = 0 - while (node < numNodes){ + while (node < numNodes) { val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node val binsForNode: Array[Double] = getBinDataForNode(node) logDebug("nodeImpurityIndex = " + nodeImpurityIndex) @@ -907,12 +872,9 @@ object DecisionTree extends Serializable with Logging { node += 1 } - //Return best splits bestSplits } - - /** * Returns split and bins for decision tree calculation. * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing * parameters for construction the DecisionTree * @return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree - * .model.Split] of size (numFeatures,numSplits-1) and bins is an Array of [org.apache - * .spark.mllib.tree.model.Bin] of size (numFeatures,numSplits1) + * .model.Split] of size (numFeatures, numBins - 1) and bins is an Array of [org.apache + * .spark.mllib.tree.model.Bin] of size (numFeatures, numBins) */ def findSplitsBins( input: RDD[LabeledPoint], - strategy: Strategy) - : (Array[Array[Split]], Array[Array[Bin]]) = { - + strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { val count = input.count() // Find the number of features by looking at the first sample val numFeatures = input.take(1)(0).features.length logDebug("numFeatures = " + numFeatures) val maxBins = strategy.maxBins val numBins = if (maxBins <= count) maxBins else count.toInt logDebug("numBins = " + numBins) - // I will also add a require statement ensuring #bins is always greater than the categories - // It's a limitation of the current implementation but a reasonable tradeoff since features - // with large number of categories get favored over continuous features. - if (strategy.categoricalFeaturesInfo.size > 0){ + /* + * TODO: Add a require statement ensuring #bins is always greater than the categories. + * It's a limitation of the current implementation but a reasonable trade-off since features + * with a large number of categories get favored over continuous features. + */ + if (strategy.categoricalFeaturesInfo.size > 0) { val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2 require(numBins >= maxCategoriesForFeatures) } - // Calculate the number of sample for approximate quantile calculation + // Calculate the number of samples for approximate quantile calculation. 
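A worked instance of the sampling arithmetic computed next (all values invented): requesting `numBins * numBins` samples means a 100-bin configuration asks for 10,000 rows, so a million-row input is sampled at a 1% fraction.

```scala
val numBins = 100
val count = 1000000L                     // total number of input rows (made up)
val requiredSamples = numBins * numBins  // 10000 samples for quantile estimation
val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
assert(fraction == 0.01)                 // only 1% of the data is collected
```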
val requiredSamples = numBins*numBins val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0 logDebug("fraction of data used for calculating quantiles = " + fraction) @@ -958,23 +920,23 @@ object DecisionTree extends Serializable with Logging { logDebug("stride = " + stride) strategy.quantileCalculationStrategy match { - case Sort => { - val splits = Array.ofDim[Split](numFeatures, numBins-1) + case Sort => + val splits = Array.ofDim[Split](numFeatures, numBins - 1) val bins = Array.ofDim[Bin](numFeatures, numBins) - // Find all splits + // Find all splits. - // Iterating over all features + // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures){ - // Checking whether the feature is continuous + // Check whether the feature is continuous. val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { - val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted - val stride: Double = numSamples.toDouble/numBins + val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted + val stride: Double = numSamples.toDouble / numBins logDebug("stride = " + stride) - for (index <- 0 until numBins-1) { - val sampleIndex = (index + 1)*stride.toInt + for (index <- 0 until numBins - 1) { + val sampleIndex = (index + 1) * stride.toInt val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List()) splits(featureIndex)(index) = split } @@ -984,87 +946,78 @@ object DecisionTree extends Serializable with Logging { "of bins") // For categorical variables, each bin is a category. The bins are sorted and they - // are ordered by calculating the centriod of their corresponding labels. + // are ordered by calculating the centroid of their corresponding labels. - val centriodForCategories - = sampledInput.map(lp => (lp.features(featureIndex),lp.label)) - .groupBy(_._1).mapValues(x => x.map(_._2).sum / x.map(_._1).length) - - // Checking for missing categorical variables and putting them last in the sorted list - val fullCentriodForCategories = scala.collection.mutable.Map[Double,Double]() + val centroidForCategories = + sampledInput.map(lp => (lp.features(featureIndex), lp.label)) + .groupBy(_._1) + .mapValues(x => x.map(_._2).sum / x.map(_._1).length) + + // Check for missing categorical variables and put them last in the sorted list. 
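A toy standalone version of the `Sort` strategy's continuous-feature logic above: sort the sampled values and take equally spaced order statistics as split thresholds (the sample data is invented):

```scala
val featureSamples = Array(4.0, 1.0, 3.0, 9.0, 7.0, 2.0, 8.0, 5.0, 6.0, 0.0).sorted
val numBins = 5
val stride: Double = featureSamples.length.toDouble / numBins  // 2.0
val thresholds = (0 until numBins - 1).map { index =>
  featureSamples((index + 1) * stride.toInt)                   // 2.0, 4.0, 6.0, 8.0
}
// The resulting bins are the intervals (-inf, 2.0], (2.0, 4.0], ..., (8.0, +inf).
```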
+ val fullCentroidForCategories = scala.collection.mutable.Map[Double, Double]() for (i <- 0 until maxFeatureValue) { - if (centriodForCategories.contains(i)) { - fullCentriodForCategories(i) = centriodForCategories(i) + if (centroidForCategories.contains(i)) { + fullCentroidForCategories(i) = centroidForCategories(i) } else { - fullCentriodForCategories(i) = Double.MaxValue + fullCentroidForCategories(i) = Double.MaxValue } } - //bins sorted by centriods - val categoriesSortedByCentriod - = fullCentriodForCategories.toList.sortBy{_._2} + // bins sorted by centroids + val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2) - logDebug("centriod for categorical variable = " + categoriesSortedByCentriod) + logDebug("centroid for categorical variable = " + categoriesSortedByCentroid) var categoriesForSplit = List[Double]() - categoriesSortedByCentriod.iterator.zipWithIndex foreach { - case((key, value), index) => { + categoriesSortedByCentroid.iterator.zipWithIndex.foreach { + case ((key, value), index) => categoriesForSplit = key :: categoriesForSplit splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit) bins(featureIndex)(index) = { - if(index == 0) { + if (index == 0) { new Bin(new DummyCategoricalSplit(featureIndex, Categorical), splits(featureIndex)(0), Categorical, key) - } - else { + } else { new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), Categorical, key) } } - } } } featureIndex += 1 } - // Find all bins + // Find all bins. featureIndex = 0 - while (featureIndex < numFeatures){ + while (featureIndex < numFeatures) { val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { // bins for categorical variables are already assigned - bins(featureIndex)(0) - = new Bin(new DummyLowSplit(featureIndex, Continuous),splits(featureIndex)(0), - Continuous,Double.MinValue) + if (isFeatureContinuous) { // Bins for categorical variables are already assigned. + bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), + splits(featureIndex)(0), Continuous, Double.MinValue) for (index <- 1 until numBins - 1){ val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), Continuous, Double.MinValue) bins(featureIndex)(index) = bin } - bins(featureIndex)(numBins-1) - = new Bin(splits(featureIndex)(numBins-2),new DummyHighSplit(featureIndex, - Continuous), Continuous, Double.MinValue) + bins(featureIndex)(numBins-1) = new Bin(splits(featureIndex)(numBins-2), + new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue) } featureIndex += 1 } (splits,bins) - } case MinMax => { + case MinMax => throw new UnsupportedOperationException("minmax not supported yet.") - } case ApproxHist => { + case ApproxHist => throw new UnsupportedOperationException("approximate histogram not supported yet.") - } } } - val usage = """ Usage: DecisionTreeRunner <master>[slices] --algo <Classification, Regression> --trainDataDir path --testDataDir path --maxDepth num [--impurity <Gini, Entropy, Variance>] [--maxBins num] """ - def main(args: Array[String]) { if (args.length < 2) { @@ -1093,20 +1046,20 @@ object DecisionTree extends Serializable with Logging { sys.exit(1) } } - val options = nextOption(Map(),argList) + val options = nextOption(Map(), argList) logDebug(options.toString()) - // Load training data + // Load training data. val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString) - // Identify the type of algorithm + // Identify the type of algorithm. 
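For reference, a hypothetical argument list for this runner and, roughly, the option map the `nextOption` parser would produce from it (paths are placeholders; `nextOption`'s body is not shown in this hunk, so the exact value types are an assumption):

```scala
val exampleArgs = List("--algo", "Classification", "--trainDataDir", "/tmp/train",
  "--testDataDir", "/tmp/test", "--maxDepth", "5", "--impurity", "Gini")
// nextOption(Map(), exampleArgs) should yield approximately:
// Map('algo -> "Classification", 'trainDataDir -> "/tmp/train",
//     'testDataDir -> "/tmp/test", 'maxDepth -> "5", 'impurity -> "Gini")
```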
val algoStr = options.get('algo).get.toString val algo = algoStr match { case "Classification" => Classification case "Regression" => Regression } - // Identify the type of impurity + // Identify the type of impurity. val impurityStr = options.getOrElse('impurity, if (algo == Classification) "Gini" else "Variance").toString val impurity = impurityStr match { case "Gini" => Gini case "Entropy" => Entropy case "Variance" => Variance } - val maxDepth = options.getOrElse('maxDepth,"1").toString.toInt - val maxBins = options.getOrElse('maxBins,"100").toString.toInt + val maxDepth = options.getOrElse('maxDepth, "1").toString.toInt + val maxBins = options.getOrElse('maxBins, "100").toString.toInt val strategy = new Strategy(algo, impurity, maxDepth, maxBins) val model = DecisionTree.train(trainData, strategy) - // Load test data + // Load test data. val testData = loadLabeledData(sc, options.get('testDataDir).get.toString) // Measure algorithm accuracy - if (algo == Classification){ + if (algo == Classification) { val accuracy = accuracyScore(model, testData) logDebug("accuracy = " + accuracy) } - if (algo == Regression){ + if (algo == Regression) { val mse = meanSquaredError(model, testData) logDebug("mean square error = " + mse) } @@ -1140,7 +1093,7 @@ object DecisionTree extends Serializable with Logging { /** * Load labeled data from a file. The data format used here is - * <L>, <f1> <f2> ... + * <L>, <f1> <f2> ..., * where <f1>, <f2> are feature values in Double and <L> is the corresponding label as Double. * * @param sc SparkContext @@ -1157,12 +1110,12 @@ object DecisionTree extends Serializable with Logging { } } - // TODO: Port this method to a generic metrics package + // TODO: Port this method to a generic metrics package. /** * Calculates the classifier accuracy. */ def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint], - threshold: Double = 0.5): Double = { + threshold: Double = 0.5): Double = { def predictedValue(features: Array[Double]) = { if (model.predict(features) < threshold) 0.0 else 1.0 } @@ -1175,9 +1128,12 @@ object DecisionTree extends Serializable with Logging { // TODO: Port this method to a generic metrics package /** - * Calculates the mean squared error for regression + * Calculates the mean squared error for regression. */ def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { - data.map(y => (tree.predict(y.features) - y.label)*(tree.predict(y.features) - y.label)).mean() + data.map { y => + val err = tree.predict(y.features) - y.label + err * err + }.mean() } }
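Finally, a hypothetical end-to-end use of the public API touched by this patch, mirroring what `main` does (master URL and paths are placeholders):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini

object DecisionTreeExample extends App {
  val sc = new SparkContext("local[2]", "DecisionTreeExample")
  // loadLabeledData expects "<L>, <f1> <f2> ..." lines, as documented above.
  val trainData = DecisionTree.loadLabeledData(sc, "/tmp/train.csv")
  val testData = DecisionTree.loadLabeledData(sc, "/tmp/test.csv")
  val strategy = new Strategy(Classification, Gini, 3)  // algo, impurity, maxDepth
  val model = DecisionTree.train(trainData, strategy)
  println("accuracy = " + DecisionTree.accuracyScore(model, testData))
  sc.stop()
}
```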