Skip to content

Commit

Permalink
unit tests for categorical features
Browse files Browse the repository at this point in the history
Signed-off-by: Manish Amde <[email protected]>
  • Loading branch information
manishamde committed Feb 28, 2014
1 parent f067d68 commit 5841c28
Showing 1 changed file with 191 additions and 37 deletions.
228 changes: 191 additions & 37 deletions mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ import org.apache.spark.SparkContext._
import org.jblas._
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model.Filter
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
import scala.collection.mutable
import org.apache.spark.mllib.tree.configuration.FeatureType._

class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {

Expand All @@ -56,7 +57,6 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bins.length==2)
assert(splits(0).length==99)
assert(bins(0).length==100)
//println(splits(1)(98))
}

test("split and bin calculation for categorical variables"){
Expand All @@ -69,13 +69,71 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(bins.length==2)
assert(splits(0).length==99)
assert(bins(0).length==100)
println(splits(0)(0))
println(splits(0)(1))
println(bins(0)(0))
println(splits(1)(0))
println(splits(1)(1))
println(bins(1)(0))
//TODO: Add asserts

//Checking splits

assert(splits(0)(0).feature == 0)
assert(splits(0)(0).threshold == Double.MinValue)
assert(splits(0)(0).featureType == Categorical)
assert(splits(0)(0).categories.length == 1)
assert(splits(0)(0).categories.contains(1.0))


assert(splits(0)(1).feature == 0)
assert(splits(0)(1).threshold == Double.MinValue)
assert(splits(0)(1).featureType == Categorical)
assert(splits(0)(1).categories.length == 2)
assert(splits(0)(1).categories.contains(1.0))
assert(splits(0)(1).categories.contains(0.0))

assert(splits(0)(2) == null)

assert(splits(1)(0).feature == 1)
assert(splits(1)(0).threshold == Double.MinValue)
assert(splits(1)(0).featureType == Categorical)
assert(splits(1)(0).categories.length == 1)
assert(splits(1)(0).categories.contains(0.0))


assert(splits(1)(1).feature == 1)
assert(splits(1)(1).threshold == Double.MinValue)
assert(splits(1)(1).featureType == Categorical)
assert(splits(1)(1).categories.length == 2)
assert(splits(1)(1).categories.contains(1.0))
assert(splits(1)(1).categories.contains(0.0))

assert(splits(1)(2) == null)


// Checks bins

assert(bins(0)(0).category == 1.0)
assert(bins(0)(0).lowSplit.categories.length == 0)
assert(bins(0)(0).highSplit.categories.length == 1)
assert(bins(0)(0).highSplit.categories.contains(1.0))

assert(bins(0)(1).category == 0.0)
assert(bins(0)(1).lowSplit.categories.length == 1)
assert(bins(0)(1).lowSplit.categories.contains(1.0))
assert(bins(0)(1).highSplit.categories.length == 2)
assert(bins(0)(1).highSplit.categories.contains(1.0))
assert(bins(0)(1).highSplit.categories.contains(0.0))

assert(bins(0)(2).category == Double.MaxValue)

assert(bins(1)(0).category == 0.0)
assert(bins(1)(0).lowSplit.categories.length == 0)
assert(bins(1)(0).highSplit.categories.length == 1)
assert(bins(1)(0).highSplit.categories.contains(0.0))

assert(bins(1)(1).category == 1.0)
assert(bins(1)(1).lowSplit.categories.length == 1)
assert(bins(1)(1).lowSplit.categories.contains(0.0))
assert(bins(1)(1).highSplit.categories.length == 2)
assert(bins(1)(1).highSplit.categories.contains(0.0))
assert(bins(1)(1).highSplit.categories.contains(1.0))

assert(bins(1)(2).category == Double.MaxValue)

}

Expand All @@ -85,29 +143,106 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
assert(splits.length==2)
assert(bins.length==2)
assert(splits(0).length==99)
assert(bins(0).length==100)
println(splits(0)(0))
println(splits(0)(1))
println(splits(0)(2))
println(bins(0)(0))
println(bins(0)(1))
println(bins(0)(2))
println(splits(1)(0))
println(splits(1)(1))
println(splits(1)(2))
println(bins(1)(0))
println(bins(1)(1))
println(bins(0)(2))
println(bins(0)(3))
//TODO: Add asserts

}
//Checking splits

assert(splits(0)(0).feature == 0)
assert(splits(0)(0).threshold == Double.MinValue)
assert(splits(0)(0).featureType == Categorical)
assert(splits(0)(0).categories.length == 1)
assert(splits(0)(0).categories.contains(1.0))

assert(splits(0)(1).feature == 0)
assert(splits(0)(1).threshold == Double.MinValue)
assert(splits(0)(1).featureType == Categorical)
assert(splits(0)(1).categories.length == 2)
assert(splits(0)(1).categories.contains(1.0))
assert(splits(0)(1).categories.contains(0.0))

assert(splits(0)(2).feature == 0)
assert(splits(0)(2).threshold == Double.MinValue)
assert(splits(0)(2).featureType == Categorical)
assert(splits(0)(2).categories.length == 3)
assert(splits(0)(2).categories.contains(1.0))
assert(splits(0)(2).categories.contains(0.0))
assert(splits(0)(2).categories.contains(2.0))

assert(splits(0)(3) == null)

assert(splits(1)(0).feature == 1)
assert(splits(1)(0).threshold == Double.MinValue)
assert(splits(1)(0).featureType == Categorical)
assert(splits(1)(0).categories.length == 1)
assert(splits(1)(0).categories.contains(0.0))

assert(splits(1)(1).feature == 1)
assert(splits(1)(1).threshold == Double.MinValue)
assert(splits(1)(1).featureType == Categorical)
assert(splits(1)(1).categories.length == 2)
assert(splits(1)(1).categories.contains(1.0))
assert(splits(1)(1).categories.contains(0.0))

assert(splits(1)(2).feature == 1)
assert(splits(1)(2).threshold == Double.MinValue)
assert(splits(1)(2).featureType == Categorical)
assert(splits(1)(2).categories.length == 3)
assert(splits(1)(2).categories.contains(1.0))
assert(splits(1)(2).categories.contains(0.0))
assert(splits(1)(2).categories.contains(2.0))

assert(splits(1)(3) == null)


// Checks bins

assert(bins(0)(0).category == 1.0)
assert(bins(0)(0).lowSplit.categories.length == 0)
assert(bins(0)(0).highSplit.categories.length == 1)
assert(bins(0)(0).highSplit.categories.contains(1.0))

assert(bins(0)(1).category == 0.0)
assert(bins(0)(1).lowSplit.categories.length == 1)
assert(bins(0)(1).lowSplit.categories.contains(1.0))
assert(bins(0)(1).highSplit.categories.length == 2)
assert(bins(0)(1).highSplit.categories.contains(1.0))
assert(bins(0)(1).highSplit.categories.contains(0.0))

assert(bins(0)(2).category == 2.0)
assert(bins(0)(2).lowSplit.categories.length == 2)
assert(bins(0)(2).lowSplit.categories.contains(1.0))
assert(bins(0)(2).lowSplit.categories.contains(0.0))
assert(bins(0)(2).highSplit.categories.length == 3)
assert(bins(0)(2).highSplit.categories.contains(1.0))
assert(bins(0)(2).highSplit.categories.contains(0.0))
assert(bins(0)(2).highSplit.categories.contains(2.0))

assert(bins(0)(3).category == Double.MaxValue)

assert(bins(1)(0).category == 0.0)
assert(bins(1)(0).lowSplit.categories.length == 0)
assert(bins(1)(0).highSplit.categories.length == 1)
assert(bins(1)(0).highSplit.categories.contains(0.0))

assert(bins(1)(1).category == 1.0)
assert(bins(1)(1).lowSplit.categories.length == 1)
assert(bins(1)(1).lowSplit.categories.contains(0.0))
assert(bins(1)(1).highSplit.categories.length == 2)
assert(bins(1)(1).highSplit.categories.contains(0.0))
assert(bins(1)(1).highSplit.categories.contains(1.0))

assert(bins(1)(2).category == 2.0)
assert(bins(1)(2).lowSplit.categories.length == 2)
assert(bins(1)(2).lowSplit.categories.contains(0.0))
assert(bins(1)(2).lowSplit.categories.contains(1.0))
assert(bins(1)(2).highSplit.categories.length == 3)
assert(bins(1)(2).highSplit.categories.contains(0.0))
assert(bins(1)(2).highSplit.categories.contains(1.0))
assert(bins(1)(2).highSplit.categories.contains(2.0))

assert(bins(1)(3).category == Double.MaxValue)

//TODO: Test max feature value > num bins

}

test("classification stump with all categorical variables"){
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
Expand All @@ -117,22 +252,41 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
strategy.numBins = 100
val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins)
println(bestSplits(0)._1)
println(bestSplits(0)._2)
//TODO: Add asserts

val split = bestSplits(0)._1
assert(split.categories.length == 1)
assert(split.categories.contains(1.0))
assert(split.featureType == Categorical)
assert(split.threshold == Double.MinValue)

val stats = bestSplits(0)._2
assert(stats.gain > 0)
assert(stats.predict > 0.4)
assert(stats.predict < 0.5)
assert(stats.impurity > 0.2)

}

test("regression stump with all categorical variables"){
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length == 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification,Gini,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
val strategy = new Strategy(Regression,Variance,3,100,categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
strategy.numBins = 100
val bestSplits = DecisionTree.findBestSplits(rdd,new Array(7),strategy,0,Array[List[Filter]](),splits,bins)
println(bestSplits(0)._1)
println(bestSplits(0)._2)
//TODO: Add asserts

val split = bestSplits(0)._1
assert(split.categories.length == 1)
assert(split.categories.contains(1.0))
assert(split.featureType == Categorical)
assert(split.threshold == Double.MinValue)

val stats = bestSplits(0)._2
assert(stats.gain > 0)
assert(stats.predict > 0.4)
assert(stats.predict < 0.5)
assert(stats.impurity > 0.2)
}


Expand All @@ -157,7 +311,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
assert(0==bestSplits(0)._2.gain)
assert(0==bestSplits(0)._2.leftImpurity)
assert(0==bestSplits(0)._2.rightImpurity)
println(bestSplits(0)._2.predict)

}

test("stump with fixed label 1 for Gini"){
Expand Down

0 comments on commit 5841c28

Please sign in to comment.