diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b8164f64a7b04..aaa5a4fef6697 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -306,6 +306,20 @@ object DecisionTree extends Serializable with Logging { arr } + /** + Performs a sequential aggregation over a partition for classification. + + for p bins, k features, l nodes (level = log2(l)) storage is of the form: + b111_left_count,b111_right_count, .... , .. + .. bpk1_left_count, bpk1_right_count, .... , .. + .. bpkl_left_count, bpkl_right_count + + @param agg Array[Double] storing aggregate calculation of size + 2*numSplits*numFeatures*numNodes for classification + @param arr Array[Double] of size 1+(numFeatures*numNodes) + @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes + for classification + */ def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { for (node <- 0 until numNodes) { val validSignalIndex = 1 + numFeatures * node @@ -326,6 +340,20 @@ object DecisionTree extends Serializable with Logging { } } + /** + Performs a sequential aggregation over a partition for regression. + + for p bins, k features, l nodes (level = log2(l)) storage is of the form: + b111_count,b111_sum, b111_sum_squares .... , .. + .. bpk1_count, bpk1_sum, bpk1_sum_squares, .... , .. + .. bpkl_count, bpkl_sum, bpkl_sum_squares + + @param agg Array[Double] storing aggregate calculation of size + 3*numSplits*numFeatures*numNodes for classification + @param arr Array[Double] of size 1+(numFeatures*numNodes) + @return Array[Double] storing aggregate calculation of size 3*numSplits*numFeatures*numNodes + for regression + */ def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) { for (node <- 0 until numNodes) { val validSignalIndex = 1 + numFeatures * node @@ -354,11 +382,11 @@ object DecisionTree extends Serializable with Logging { .. bpk1_left_count, bpk1_right_count, .... , .. .. bpkl_left_count, bpkl_right_count - @param agg Array[Double] storing aggregate calculation of size - 2*numSplits*numFeatures*numNodes for classification + @param agg Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification + and 3*numSplits*numFeatures*numNodes for regression @param arr Array[Double] of size 1+(numFeatures*numNodes) - @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes - for classification + @return Array[Double] storing aggregate calculation of size 2*numSplits*numFeatures*numNodes for classification + and 3*numSplits*numFeatures*numNodes for regression */ def binSeqOp(agg : Array[Double], arr: Array[Double]) : Array[Double] = { strategy.algo match { @@ -411,7 +439,15 @@ object DecisionTree extends Serializable with Logging { logDebug("binAggregates.length = " + binAggregates.length) //binAggregates.foreach(x => logDebug(x)) - + /** + * Calculates the information gain for all splits + * @param leftNodeAgg left node aggregates + * @param featureIndex feature index + * @param splitIndex split index + * @param rightNodeAgg right node aggregate + * @param topImpurity impurity of the parent node + * @return information gain and statistics for all splits + */ def calculateGainForSplit(leftNodeAgg: Array[Array[Double]], featureIndex: Int, splitIndex: Int,