From c0e522b7d1f5e27c81d682e5c8c97543fb4242be Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 26 Jan 2014 19:11:43 -0800 Subject: [PATCH] updated predict and split threshold logic Signed-off-by: Manish Amde --- .../scala/org/apache/spark/mllib/tree/DecisionTree.scala | 9 +++++---- .../org/apache/spark/mllib/tree/DecisionTreeRunner.scala | 1 - .../spark/mllib/tree/model/DecisionTreeModel.scala | 2 +- .../spark/mllib/tree/model/InformationGainStats.scala | 7 ++++--- .../scala/org/apache/spark/mllib/tree/model/Node.scala | 1 + .../org/apache/spark/mllib/tree/DecisionTreeSuite.scala | 8 ++++---- 6 files changed, 15 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index ab2c9011dd93b..865a95c5025fc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -211,7 +211,7 @@ object DecisionTree extends Serializable with Logging { val lowThreshold = bin.lowSplit.threshold val highThreshold = bin.highSplit.threshold val features = labeledPoint.features - if ((lowThreshold <= features(featureIndex)) & (highThreshold > features(featureIndex))) { + if ((lowThreshold < features(featureIndex)) & (highThreshold >= features(featureIndex))) { return binIndex } } @@ -400,7 +400,8 @@ object DecisionTree extends Serializable with Logging { } } - val predict = leftCount / (leftCount + rightCount) + //val predict = leftCount / (leftCount + rightCount) + val predict = (left1Count + right1Count) / (leftCount + rightCount) new InformationGainStats(gain,impurity,leftImpurity,rightImpurity,predict) } @@ -672,8 +673,8 @@ object DecisionTree extends Serializable with Logging { //Find all bins for (featureIndex <- 0 until numFeatures){ - val isFeatureContinous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinous) { //bins for categorical variables are already assigned + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { //bins for categorical variables are already assigned bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),splits(featureIndex)(0),Continuous,Double.MinValue) for (index <- 1 until numBins - 1){ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala index ae18cb0aaa4e7..4e6ed768d55d3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTreeRunner.scala @@ -133,7 +133,6 @@ object DecisionTreeRunner extends Logging { //TODO: Make these generic MLTable metrics def meanSquaredError(tree : DecisionTreeModel, data : RDD[LabeledPoint]) : Double = { val meanSumOfSquares = data.map(y => (tree.predict(y.features) - y.label)*(tree.predict(y.features) - y.label)).mean() - println("meanSumOfSquares = " + meanSumOfSquares) meanSumOfSquares } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 587e549c34ca8..0da42e826984c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -24,7 +24,7 @@ class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializabl def predict(features : Array[Double]) = { algo match { case Classification => { - if (topNode.predictIfLeaf(features) >= 0.5) 0.0 else 1.0 + if (topNode.predictIfLeaf(features) < 0.5) 0.0 else 1.0 } case Regression => { topNode.predictIfLeaf(features) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index b992684b2b05b..55d5893ee93c2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -24,9 +24,10 @@ class InformationGainStats(val gain : Double, //val rightSamples : Long val predict : Double) extends Serializable { - override def toString = - "gain = " + gain + ", impurity = " + impurity + ", left impurity = " - + leftImpurity + ", right impurity = " + rightImpurity + ", predict = " + predict + override def toString = { + "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f" + .format(gain, impurity, leftImpurity, rightImpurity, predict) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index fb63743848cc9..508b7b31d83b6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -34,6 +34,7 @@ class Node ( val id : Int, def build(nodes : Array[Node]) : Unit = { logDebug("building node " + id + " at level " + (scala.math.log(id + 1)/scala.math.log(2)).toInt ) + logDebug("id = " + id + ", split = " + split) logDebug("stats = " + stats) logDebug("predict = " + predict) if (!isLeaf) { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 8d5ed343e0eb4..15b5b40b06532 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -157,7 +157,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(0==bestSplits(0)._2.gain) assert(0==bestSplits(0)._2.leftImpurity) assert(0==bestSplits(0)._2.rightImpurity) - assert(0.01==bestSplits(0)._2.predict) + println(bestSplits(0)._2.predict) } test("stump with fixed label 1 for Gini"){ @@ -181,7 +181,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(0==bestSplits(0)._2.gain) assert(0==bestSplits(0)._2.leftImpurity) assert(0==bestSplits(0)._2.rightImpurity) - assert(0.01==bestSplits(0)._2.predict) + assert(1==bestSplits(0)._2.predict) } @@ -207,7 +207,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(0==bestSplits(0)._2.gain) assert(0==bestSplits(0)._2.leftImpurity) assert(0==bestSplits(0)._2.rightImpurity) - assert(0.01==bestSplits(0)._2.predict) + assert(0==bestSplits(0)._2.predict) } test("stump with fixed label 1 for Entropy"){ @@ -231,7 +231,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(0==bestSplits(0)._2.gain) assert(0==bestSplits(0)._2.leftImpurity) assert(0==bestSplits(0)._2.rightImpurity) - assert(0.01==bestSplits(0)._2.predict) + assert(1==bestSplits(0)._2.predict) }