diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b7492038445cc..3d5eb0fcf263b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -81,6 +81,8 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log // Each data sample is checked for validity w.r.t to each node at a given level -- i.e., // the sample is only used for the split calculation at the node if the sampled would have // still survived the filters of the parent nodes. + + // TODO: Convert for loop to while loop breakable { for (level <- 0 until maxDepth) { @@ -120,7 +122,7 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log } /** - * Extract the decision tree node information for th given tree level and node index + * Extract the decision tree node information for the given tree level and node index */ private def extractNodeInfo( nodeSplitStats: (Split, InformationGainStats), @@ -151,6 +153,7 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log : Unit = { // 0 corresponds to the left child node and 1 corresponds to the right child node. + // TODO: Convert to while loop for (i <- 0 to 1) { // Calculating the index of the node from the node level and the index at the current level val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i @@ -264,6 +267,31 @@ object DecisionTree extends Serializable with Logging { bins: Array[Array[Bin]]) : Array[(Split, InformationGainStats)] = { + + // The high-level description for the best split optimizations are noted here. + // + // *Level-wise training* + // We perform bin calculations for all nodes at the given level to avoid making multiple + // passes over the data. Thus, for a slightly increased computation and storage cost we save + // several iterations over the data especially at higher levels of the decision tree. + // + // *Bin-wise computation* + // We use a bin-wise best split computation strategy instead of a straightforward best split + // computation strategy. Instead of analyzing each sample for contribution to the left/right + // child node impurity of every split, we first categorize each feature of a sample into a + // bin. Each bin is an interval between a low and high split. Since each splits, and thus bin, + // is ordered (read ordering for categorical variables in the findSplitsBins method), + // we exploit this structure to calculate aggregates for bins and then use these aggregates + // to calculate information gain for each split. + // + // *Aggregation over partitions* + // Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know + // the number of splits in advance. Thus, we store the aggregates (at the appropriate + // indices) in a single array for all bins and rely upon the RDD aggregate method to + // drastically reduce the communication overhead. + + // Implementation below + // Common calculations for multiple nested methods val numNodes = scala.math.pow(2, level).toInt logDebug("numNodes = " + numNodes) @@ -294,6 +322,7 @@ object DecisionTree extends Serializable with Logging { return false } + // Apply each filter and check sample validity. Return false when invalid condition found. for (filter <- parentFilters) { val features = labeledPoint.features val featureIndex = filter.split.feature @@ -316,12 +345,13 @@ object DecisionTree extends Serializable with Logging { } } + + //Return true when the sample is valid for all filters true } - // TODO: Unit test this /** - * Finds the right bin for the given feature + * Find bin for one feature */ def findBin( featureIndex: Int, @@ -332,9 +362,12 @@ object DecisionTree extends Serializable with Logging { val binForFeatures = bins(featureIndex) val feature = labeledPoint.features(featureIndex) + /** + * Binary search helper method for continuous feature + */ def binarySearchForBins(): Int = { var left = 0 - var right = binForFeatures.length-1 + var right = binForFeatures.length - 1 while (left <= right) { val mid = left + (right - left) / 2 val bin = binForFeatures(mid) @@ -353,13 +386,10 @@ object DecisionTree extends Serializable with Logging { -1 } - if (isFeatureContinuous){ - val binIndex = binarySearchForBins() - if (binIndex == -1){ - throw new UnknownError("no bin was found for continuous variable.") - } - binIndex - } else { + /** + * Sequential search helper method to find bin for categorical feature + */ + def sequentialBinSearchForCategoricalFeature() : Int = { val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex) var binIndex = 0 while (binIndex < numCategoricalBins) { @@ -371,26 +401,40 @@ object DecisionTree extends Serializable with Logging { } binIndex += 1 } - throw new UnknownError("no bin was found for categorical variable.") - + -1 } + if (isFeatureContinuous){ + // Perform binary search for finding bin for continuous features. + val binIndex = binarySearchForBins() + if (binIndex == -1){ + throw new UnknownError("no bin was found for continuous variable.") + } + binIndex + } else { + // Perform sequential search to find bin for categorical features. + val binIndex = sequentialBinSearchForCategoricalFeature() + if (binIndex == -1){ + throw new UnknownError("no bin was found for categorical variable.") + } + binIndex + } } /** - * Finds bins for all nodes (and all features) at a given level k features, - * l nodes (level = log2(l)). - * Storage label, b11, b12, b13, .., b1k, - * b21, b22, .. , b2k, - * bl1, bl2, .. , blk - * Denotes invalid sample for tree by noting bin for feature 1 as -1 + * Finds bins for all nodes (and all features) at a given level. + * For l nodes, k features the storage is as follows: + * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk + * where b_ij is an integer between 0 and numBins - 1. + * Invalid sample is denoted by noting bin for feature 1 as -1 */ def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { // calculating bin index and label per feature per node val arr = new Array[Double](1 + (numFeatures * numNodes)) arr(0) = labeledPoint.label - for (nodeIndex <- 0 until numNodes) { + var nodeIndex = 0 + while (nodeIndex < numNodes) { val parentFilters = findParentFilters(nodeIndex) // Find out whether the sample qualifies for the particular node val sampleValid = isSampleValid(parentFilters, labeledPoint) @@ -406,17 +450,15 @@ object DecisionTree extends Serializable with Logging { featureIndex += 1 } } + nodeIndex += 1 } arr } /** - * Performs a sequential aggregation over a partition for classification. - * - * for p bins, k features, l nodes (level = log2(l)) storage is of the form: - * b111_left_count,b111_right_count, .... , .. - * .. bpk1_left_count, bpk1_right_count, .... , .. - * .. bpkl_left_count, bpkl_right_count + * Performs a sequential aggregation over a partition for classification. For l nodes, + * k features, either the left count or the right count of one of the p bins is + * incremented based upon whether the feature is classified as 0 or 1. * * @param agg Array[Double] storing aggregate calculation of size * 2*numSplits*numFeatures*numNodes for classification @@ -425,32 +467,38 @@ object DecisionTree extends Serializable with Logging { * for classification */ def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { - for (nodeIndex <- 0 until numNodes) { + // Iterating over all nodes + var nodeIndex = 0 + while (nodeIndex < numNodes) { + // Checking whether the instance was valid for this nodeIndex val validSignalIndex = 1 + numFeatures * nodeIndex val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex if (isSampleValidForNode) { + // Actual class label val label = arr(0) - for (featureIndex <- 0 until numFeatures) { + // Iterating over all features + var featureIndex = 0 + while (featureIndex < numFeatures) { + // Finding the bin index for this feature val arrShift = 1 + numFeatures * nodeIndex - val aggShift = 2 * numBins * numFeatures * nodeIndex val arrIndex = arrShift + featureIndex + // Updating the left or right count for one bin + val aggShift = 2 * numBins * numFeatures * nodeIndex val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2 label match { case (0.0) => agg(aggIndex) = agg(aggIndex) + 1 case (1.0) => agg(aggIndex + 1) = agg(aggIndex + 1) + 1 } + featureIndex += 1 } } + nodeIndex += 1 } } /** - * Performs a sequential aggregation over a partition for regression. - * - * for p bins, k features, l nodes (level = log2(l)) storage is of the form: - * b111_count,b111_sum, b111_sum_squares .... , .. - * .. bpk1_count, bpk1_sum, bpk1_sum_squares, .... , .. - * .. bpkl_count, bpkl_sum, bpkl_sum_squares + * Performs a sequential aggregation over a partition for regression. For l nodes, k features, + * the count, sum, sum of squares of one of the p bins is incremented. * * @param agg Array[Double] storing aggregate calculation of size * 3*numSplits*numFeatures*numNodes for classification @@ -459,37 +507,37 @@ object DecisionTree extends Serializable with Logging { * 3*numSplits*numFeatures*numNodes for regression */ def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) { - for (nodeIndex <- 0 until numNodes) { + // Iterating over all nodes + var nodeIndex = 0 + while (nodeIndex < numNodes) { + // Checking whether the instance was valid for this nodeIndex val validSignalIndex = 1 + numFeatures * nodeIndex val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex if (isSampleValidForNode) { + // Actual class label val label = arr(0) - for (feature <- 0 until numFeatures) { + // Iterating over all features + var featureIndex = 0 + while (featureIndex < numFeatures) { + // Finding the bin index for this feature val arrShift = 1 + numFeatures * nodeIndex + val arrIndex = arrShift + featureIndex + // updating count, sum, sum^2 for one bin val aggShift = 3 * numBins * numFeatures * nodeIndex - val arrIndex = arrShift + feature - val aggIndex = aggShift + 3 * feature * numBins + arr(arrIndex).toInt * 3 - //count, sum, sum^2 + val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3 agg(aggIndex) = agg(aggIndex) + 1 agg(aggIndex + 1) = agg(aggIndex + 1) + label agg(aggIndex + 2) = agg(aggIndex + 2) + label*label + // increment featureIndex + featureIndex += 1 } } + nodeIndex += 1 } } /** * Performs a sequential aggregation over a partition. - * for p bins, k features, l nodes (level = log2(l)) storage is of the form: - * b111_left_count,b111_right_count, .... , .... - * bpk1_left_count, bpk1_right_count, .... , ...., bpkl_left_count, bpkl_right_count - * @param agg Array[Double] storing aggregate calculation of size - * 2*numSplits*numFeatures*numNodes for classification and - * 3*numSplits*numFeatures*numNodes for regression - * @param arr Array[Double] of size 1+(numFeatures*numNodes) - * @return Array[Double] storing aggregate calculation of size - * 2*numSplits*numFeatures*numNodes for classification and - * 3*numSplits*numFeatures*numNodes for regression */ def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { strategy.algo match { @@ -499,6 +547,7 @@ object DecisionTree extends Serializable with Logging { agg } + // Calculating bin aggregate length for classification or regression val binAggregateLength = strategy.algo match { case Classification => 2*numBins * numFeatures * numNodes case Regression => 3*numBins * numFeatures * numNodes @@ -512,27 +561,17 @@ object DecisionTree extends Serializable with Logging { * @return Combined aggregate from agg1 and agg2 */ def binCombOp(agg1: Array[Double], agg2: Array[Double]): Array[Double] = { - strategy.algo match { - case Classification => { - val combinedAggregate = new Array[Double](binAggregateLength) - for (index <- 0 until binAggregateLength){ - combinedAggregate(index) = agg1(index) + agg2(index) - } - combinedAggregate - } - case Regression => { - val combinedAggregate = new Array[Double](binAggregateLength) - for (index <- 0 until binAggregateLength){ - combinedAggregate(index) = agg1(index) + agg2(index) - } - combinedAggregate - } + var index = 0 + val combinedAggregate = new Array[Double](binAggregateLength) + while (index < binAggregateLength){ + combinedAggregate(index) = agg1(index) + agg2(index) + index += 1 } + combinedAggregate } - logDebug("input = " + input.count) + // find feature bins for all nodes at a level val binMappedRDD = input.map(x => findBinsForLevel(x)) - logDebug("binMappedRDD.count = " + binMappedRDD.count) // calculate bin aggregates val binAggregates = { @@ -541,7 +580,7 @@ object DecisionTree extends Serializable with Logging { logDebug("binAggregates.length = " + binAggregates.length) /** - * Calculates the information gain for all splits + * Calculates the information gain for all splits based upon left/right split aggregates * @param leftNodeAgg left node aggregates * @param featureIndex feature index * @param splitIndex split index @@ -572,6 +611,7 @@ object DecisionTree extends Serializable with Logging { if (level > 0) { topImpurity } else { + // Calculating impurity for root node strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count) } } @@ -614,6 +654,7 @@ object DecisionTree extends Serializable with Logging { if (level > 0) { topImpurity } else { + // Calculating impurity for root node val count = leftCount + rightCount val sum = leftSum + rightSum val sumSquares = leftSumSquares + rightSumSquares @@ -623,11 +664,11 @@ object DecisionTree extends Serializable with Logging { if (leftCount == 0) { return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, - rightSum/rightCount) + rightSum / rightCount) } if (rightCount == 0) { return new InformationGainStats(0, topImpurity ,topImpurity, - Double.MinValue, leftSum/leftCount) + Double.MinValue, leftSum / leftCount) } val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares) @@ -644,16 +685,16 @@ object DecisionTree extends Serializable with Logging { } } - val predict = (leftSum + rightSum)/(leftCount + rightCount) + val predict = (leftSum + rightSum) / (leftCount + rightCount) new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) } } } - /** + /** * Extracts left and right split aggregates - * @param binData Array[Double] of size 2*numFeatures*numSplits + * @param binData Array[Double] of size 2*numFeatures*numSplits * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double], * Array[Double]) where each array is of size(numFeature,2*(numSplits-1)) */ @@ -663,58 +704,90 @@ object DecisionTree extends Serializable with Logging { strategy.algo match { case Classification => { - + // Initializing left and right split aggregates val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) - for (featureIndex <- 0 until numFeatures) { + // Iterating over all features + var featureIndex = 0 + while (featureIndex < numFeatures) { + // shift for this featureIndex val shift = 2*featureIndex*numBins + + // left node aggregate for the lowest split leftNodeAgg(featureIndex)(0) = binData(shift + 0) leftNodeAgg(featureIndex)(1) = binData(shift + 1) + + // right node aggregate for the highest split rightNodeAgg(featureIndex)(2 * (numBins - 2)) = binData(shift + (2 * (numBins - 1))) rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1) = binData(shift + (2 * (numBins - 1)) + 1) - for (splitIndex <- 1 until numBins - 1) { - leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2*splitIndex) + + + // Iterating over all splits + var splitIndex = 1 + while (splitIndex < numBins - 1) { + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2 * splitIndex) + leftNodeAgg(featureIndex)(2 * splitIndex - 2) - leftNodeAgg(featureIndex)(2 * splitIndex + 1) = binData(shift + 2*splitIndex + 1) + + leftNodeAgg(featureIndex)(2 * splitIndex + 1) = binData(shift + 2 * splitIndex + 1) + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1) + + // calculating right node aggregate for a split as a sum of right node aggregate of a + // higher split and the right bin aggregate of a bin where the split is a low split rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) = binData(shift + (2 *(numBins - 2 - splitIndex))) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex)) rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) = binData(shift + (2* (numBins - 2 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1) + + splitIndex += 1 } + featureIndex += 1 } (leftNodeAgg, rightNodeAgg) } case Regression => { - + // Initializing left and right split aggregates val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) - for (featureIndex <- 0 until numFeatures) { + // Iterating over all features + var featureIndex = 0 + while (featureIndex < numFeatures) { + // shift for this featureIndex val shift = 3*featureIndex*numBins + // left node aggregate for the lowest split leftNodeAgg(featureIndex)(0) = binData(shift + 0) leftNodeAgg(featureIndex)(1) = binData(shift + 1) leftNodeAgg(featureIndex)(2) = binData(shift + 2) + + // right node aggregate for the highest split rightNodeAgg(featureIndex)(3 * (numBins - 2)) = binData(shift + (3 * (numBins - 1))) rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) = binData(shift + (3 * (numBins - 1)) + 1) rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) = binData(shift + (3 * (numBins - 1)) + 2) - for (splitIndex <- 1 until numBins - 1) { + + // Iterating over all splits + var splitIndex = 1 + while (splitIndex < numBins - 1) { + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split leftNodeAgg(featureIndex)(3 * splitIndex) - = binData(shift + 3*splitIndex) + + = binData(shift + 3 * splitIndex) + leftNodeAgg(featureIndex)(3 * splitIndex - 3) leftNodeAgg(featureIndex)(3 * splitIndex + 1) - = binData(shift + 3*splitIndex + 1) + + = binData(shift + 3 * splitIndex + 1) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1) leftNodeAgg(featureIndex)(3 * splitIndex + 2) - = binData(shift + 3*splitIndex + 2) + + = binData(shift + 3 * splitIndex + 2) + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2) + + // calculating right node aggregate for a split as a sum of right node aggregate of a + // higher split and the right bin aggregate of a bin where the split is a low split rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) = binData(shift + (3 * (numBins - 2 - splitIndex))) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex)) @@ -724,13 +797,19 @@ object DecisionTree extends Serializable with Logging { rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) = binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2) + + splitIndex += 1 } + featureIndex += 1 } (leftNodeAgg, rightNodeAgg) } } } + /** + * Calculates information gain for all nodes splits + */ def calculateGainsForAllNodeSplits( leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], @@ -749,10 +828,10 @@ object DecisionTree extends Serializable with Logging { } /** - * Find the best split for a node given bin aggregate data + * Find the best split for a node * @param binData Array[Double] of size 2*numSplits*numFeatures * @param nodeImpurity impurity of the top node - * @return + * @return tuple of split and information gain */ def binsToBestSplit( binData: Array[Double], @@ -760,23 +839,33 @@ object DecisionTree extends Serializable with Logging { : (Split, InformationGainStats) = { logDebug("node impurity = " + nodeImpurity) + + //extract left right node aggregates val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) + + // calculate gains for all splits val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) val (bestFeatureIndex,bestSplitIndex, gainStats) = { - var bestFeatureIndex = 0 - var bestSplitIndex = 0 // Initialization with infeasible values + var bestFeatureIndex = Int.MinValue + var bestSplitIndex = Int.MinValue var bestGainStats = new InformationGainStats(Double.MinValue,-1.0,-1.0,-1.0,-1) - for (featureIndex <- 0 until numFeatures) { - for (splitIndex <- 0 until numBins - 1){ + // Iterating over features + var featureIndex = 0 + while (featureIndex < numFeatures) { + // Iterating over all splits + var splitIndex = 0 + while (splitIndex < numBins - 1){ val gainStats = gains(featureIndex)(splitIndex) if(gainStats.gain > bestGainStats.gain) { bestGainStats = gainStats bestFeatureIndex = featureIndex bestSplitIndex = splitIndex } + splitIndex += 1 } + featureIndex += 1 } (bestFeatureIndex,bestSplitIndex,bestGainStats) } @@ -786,8 +875,9 @@ object DecisionTree extends Serializable with Logging { (splits(bestFeatureIndex)(bestSplitIndex),gainStats) } - // Calculate best splits for all nodes at a given level - val bestSplits = new Array[(Split, InformationGainStats)](numNodes) + /** + * get bin data for one node + */ def getBinDataForNode(node: Int): Array[Double] = { strategy.algo match { case Classification => { @@ -803,14 +893,21 @@ object DecisionTree extends Serializable with Logging { } } - for (node <- 0 until numNodes){ + // Calculate best splits for all nodes at a given level + val bestSplits = new Array[(Split, InformationGainStats)](numNodes) + // Iterating over all nodes at this level + var node = 0 + while (node < numNodes){ val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node val binsForNode: Array[Double] = getBinDataForNode(node) logDebug("nodeImpurityIndex = " + nodeImpurityIndex) val parentNodeImpurity = parentImpurities(nodeImpurityIndex) logDebug("node impurity = " + parentNodeImpurity) bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) + node += 1 } + + //Return best splits bestSplits } @@ -865,12 +962,15 @@ object DecisionTree extends Serializable with Logging { val splits = Array.ofDim[Split](numFeatures, numBins-1) val bins = Array.ofDim[Bin](numFeatures, numBins) - //Find all splits - for (featureIndex <- 0 until numFeatures){ + // Find all splits + + // Iterating over all features + var featureIndex = 0 + while (featureIndex < numFeatures){ + // Checking whether the feature is continuous val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted - val stride: Double = numSamples.toDouble/numBins logDebug("stride = " + stride) for (index <- 0 until numBins-1) { @@ -880,15 +980,16 @@ object DecisionTree extends Serializable with Logging { } } else { val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex) - require(maxFeatureValue < numBins, "number of categories should be less than number " + "of bins") + // For categorical variables, each bin is a category. The bins are sorted and they + // are ordered by calculating the centriod of their corresponding labels. val centriodForCategories = sampledInput.map(lp => (lp.features(featureIndex),lp.label)) .groupBy(_._1).mapValues(x => x.map(_._2).sum / x.map(_._1).length) - // Checking for missing categorical variables + // Checking for missing categorical variables and putting them last in the sorted list val fullCentriodForCategories = scala.collection.mutable.Map[Double,Double]() for (i <- 0 until maxFeatureValue) { if (centriodForCategories.contains(i)) { @@ -898,6 +999,7 @@ object DecisionTree extends Serializable with Logging { } } + //bins sorted by centriods val categoriesSortedByCentriod = fullCentriodForCategories.toList.sortBy{_._2} @@ -922,10 +1024,12 @@ object DecisionTree extends Serializable with Logging { } } } + featureIndex += 1 } // Find all bins - for (featureIndex <- 0 until numFeatures){ + featureIndex = 0 + while (featureIndex < numFeatures){ val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { // bins for categorical variables are already assigned bins(featureIndex)(0) @@ -940,6 +1044,7 @@ object DecisionTree extends Serializable with Logging { = new Bin(splits(featureIndex)(numBins-2),new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue) } + featureIndex += 1 } (splits,bins) }