
Commit

mark decision tree APIs
mengxr committed Apr 7, 2014
1 parent 86b9e34 commit 0b674fa
Showing 13 changed files with 37 additions and 14 deletions.

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -41,7 +41,7 @@ class NaiveBayesModel(
private val brzTheta = new BDM[Double](theta.length, theta(0).length)

{
- // Need to put an extra pair of braces to prevent Scala treat `i` as a member.
+ // Need to put an extra pair of braces to prevent Scala treating `i` as a member.
var i = 0
while (i < theta.length) {
var j = 0

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -33,13 +33,15 @@ import org.apache.spark.util.random.XORShiftRandom
import org.apache.spark.mllib.linalg.{Vector, Vectors}

/**
+ * <span class="badge" style="float: right; background-color: #257080;">EXPERIMENTAL</span>
+ *
* A class that implements a decision tree algorithm for classification and regression. It
* supports both continuous and categorical features.
* @param strategy The configuration parameters for the tree algorithm which specify the type
* of algorithm (classification, regression, etc.), feature type (continuous,
* categorical), depth of the tree, quantile calculation strategy, etc.
*/
- class DecisionTree private(val strategy: Strategy) extends Serializable with Logging {
+ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {

/**
* Method to train a decision tree model over an RDD

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -40,4 +40,4 @@ class Strategy (
val maxDepth: Int,
val maxBins: Int = 100,
val quantileCalculationStrategy: QuantileStrategy = Sort,
- val categoricalFeaturesInfo: Map[Int,Int] = Map[Int,Int]()) extends Serializable
+ val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()) extends Serializable
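
Taken together, the two hunks above cover the public entry point being marked here: a Strategy describing the problem and a DecisionTree driven by it. A minimal sketch of how they are typically wired up, assuming the leading algo and impurity parameters of Strategy (they sit above the lines shown), a DecisionTree.train(input, strategy) helper on the companion object (not visible in this excerpt), and made-up toy data:

    import org.apache.spark.SparkContext
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.tree.configuration.Algo.Classification
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impurity.Gini

    object DecisionTreeSketch {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "DecisionTreeSketch")

        // Made-up, separable toy data: one continuous feature, binary label.
        val trainingData = sc.parallelize(0 until 200).map { i =>
          val x = i / 200.0
          LabeledPoint(if (x < 0.5) 0.0 else 1.0, Vectors.dense(x))
        }

        // Classification with Gini impurity and depth at most 3; maxBins, the
        // quantile strategy and categoricalFeaturesInfo keep their defaults.
        val strategy = new Strategy(Classification, Gini, maxDepth = 3)

        val model = DecisionTree.train(trainingData, strategy)
        println(model.predict(Vectors.dense(0.75)))  // expected to be ~1.0

        sc.stop()
      }
    }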

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -18,20 +18,24 @@
package org.apache.spark.mllib.tree.impurity

/**
+ * <span class="badge" style="float: right; background-color: #257080;">EXPERIMENTAL</span>
+ *
* Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during
* binary classification.
*/
object Entropy extends Impurity {

- def log2(x: Double) = scala.math.log(x) / scala.math.log(2)
+ private[tree] def log2(x: Double) = scala.math.log(x) / scala.math.log(2)

/**
+ * <span class="badge badge-red" style="float: right;">DEVELOPER API</span>
+ *
* entropy calculation
* @param c0 count of instances with label 0
* @param c1 count of instances with label 1
* @return entropy value
*/
- def calculate(c0: Double, c1: Double): Double = {
+ override def calculate(c0: Double, c1: Double): Double = {
if (c0 == 0 || c1 == 0) {
0
} else {
@@ -42,6 +46,6 @@ object Entropy extends Impurity {
}
}

- def calculate(count: Double, sum: Double, sumSquares: Double): Double =
+ override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
throw new UnsupportedOperationException("Entropy.calculate")
}
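
The non-degenerate branch elided above computes the standard binary entropy from the class frequencies, -f0*log2(f0) - f1*log2(f1). As a standalone illustration of that arithmetic (REPL-style, independent of the Spark object):

    // Binary entropy from two class counts; mirrors the guard shown in the diff.
    def log2(x: Double): Double = math.log(x) / math.log(2)

    def binaryEntropy(c0: Double, c1: Double): Double =
      if (c0 == 0 || c1 == 0) {
        0.0
      } else {
        val total = c0 + c1
        val f0 = c0 / total
        val f1 = c1 / total
        -f0 * log2(f0) - f1 * log2(f1)
      }

    println(binaryEntropy(2, 6))  // ~0.811: a 25%/75% node is fairly impure
    println(binaryEntropy(4, 4))  // 1.0: a 50/50 node has maximal entropy
    println(binaryEntropy(8, 0))  // 0.0: a pure node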

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -18,13 +18,17 @@
package org.apache.spark.mllib.tree.impurity

/**
+ * <span class="badge" style="float: right; background-color: #257080;">EXPERIMENTAL</span>
+ *
* Class for calculating the
* [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]]
* during binary classification.
*/
object Gini extends Impurity {

/**
+ * <span class="badge badge-red" style="float: right;">DEVELOPER API</span>
+ *
* Gini coefficient calculation
* @param c0 count of instances with label 0
* @param c1 count of instances with label 1
@@ -41,6 +45,6 @@ object Gini extends Impurity {
}
}

- def calculate(count: Double, sum: Double, sumSquares: Double): Double =
+ override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
throw new UnsupportedOperationException("Gini.calculate")
}
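
As with entropy, the body elided above is the standard two-class Gini impurity, 1 - f0^2 - f1^2. A standalone illustration, not taken from the diff:

    // Two-class Gini impurity from raw counts.
    def gini(c0: Double, c1: Double): Double = {
      val total = c0 + c1
      val f0 = c0 / total
      val f1 = c1 / total
      1.0 - f0 * f0 - f1 * f1
    }

    println(gini(2, 6))  // 0.375
    println(gini(4, 4))  // 0.5: the two-class maximum
    println(gini(8, 0))  // 0.0: a pure node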

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -18,12 +18,14 @@
package org.apache.spark.mllib.tree.impurity

/**
+ * <span class="badge" style="float: right; background-color: #257080;">EXPERIMENTAL</span>
+ *
* Trait for calculating information gain.
*/
trait Impurity extends Serializable {

/**
- * <span class="badge badge-red" style="float: right;">DEVELOPER API - UNSTABLE</span>
+ * <span class="badge badge-red" style="float: right;">DEVELOPER API</span>
*
* information calculation for binary classification
* @param c0 count of instances with label 0
@@ -33,7 +35,7 @@ trait Impurity extends Serializable {
def calculate(c0 : Double, c1 : Double): Double

/**
- * <span class="badge badge-red" style="float: right;">DEVELOPER API - UNSTABLE</span>
+ * <span class="badge badge-red" style="float: right;">DEVELOPER API</span>
*
* information calculation for regression
* @param count number of instances
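
Because the trait stays public as a developer API, a third-party impurity can plug into Strategy by implementing both overloads, following the same pattern as Entropy and Gini above. A hypothetical sketch (MisclassificationError is not part of MLlib):

    import org.apache.spark.mllib.tree.impurity.Impurity

    // Hypothetical impurity: fraction of instances mislabeled by a majority vote.
    object MisclassificationError extends Impurity {

      override def calculate(c0: Double, c1: Double): Double = {
        val total = c0 + c1
        if (total == 0) 0.0 else math.min(c0, c1) / total
      }

      // Only defined for classification, like Entropy and Gini above.
      override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
        throw new UnsupportedOperationException("MisclassificationError.calculate")
    }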

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -18,13 +18,17 @@
package org.apache.spark.mllib.tree.impurity

/**
+ * <span class="badge" style="float: right; background-color: #257080;">EXPERIMENTAL</span>
+ *
* Class for calculating variance during regression
*/
object Variance extends Impurity {
override def calculate(c0: Double, c1: Double): Double =
throw new UnsupportedOperationException("Variance.calculate")

/**
+ * <span class="badge badge-red" style="float: right;">DEVELOPER API</span>
+ *
* variance calculation
* @param count number of instances
* @param sum sum of labels
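
The regression overload elided above works from running aggregates (count, sum, and sum of squares), which is enough to recover the population variance as E[x^2] - (E[x])^2; the exact body is truncated in this diff. A standalone check of that identity:

    // Variance from streaming-friendly aggregates.
    def varianceFromAggregates(count: Double, sum: Double, sumSquares: Double): Double = {
      val mean = sum / count
      sumSquares / count - mean * mean
    }

    // Labels 1.0, 2.0, 3.0, 4.0: count = 4, sum = 10, sumSquares = 30.
    println(varianceFromAggregates(4, 10, 30))  // 1.25

    // Same value computed directly from the labels, as a sanity check.
    val labels = Seq(1.0, 2.0, 3.0, 4.0)
    val mean = labels.sum / labels.size
    println(labels.map(x => (x - mean) * (x - mean)).sum / labels.size)  // 1.25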

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
@@ -30,4 +30,5 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
* @param featureType type of feature -- categorical or continuous
* @param category categorical label value accepted in the bin
*/
+ private[tree]
case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)

mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -22,6 +22,8 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.linalg.Vector

/**
+ * <span class="badge" style="float: right; background-color: #257080;">EXPERIMENTAL</span>
+ *
* Model to store the decision tree parameters
* @param topNode root node
* @param algo algorithm type -- classification or regression
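
To make the model's role concrete, a brief usage sketch. It assumes the usual predict overloads (a single Vector and an RDD[Vector] variant); they sit below the truncated portion of the diff, so treat the exact signatures as an assumption:

    import org.apache.spark.mllib.linalg.{Vector, Vectors}
    import org.apache.spark.mllib.tree.model.DecisionTreeModel
    import org.apache.spark.rdd.RDD

    // `model` is assumed to come from DecisionTree.train, as in the earlier sketch.
    def score(model: DecisionTreeModel, testFeatures: RDD[Vector]): Unit = {
      // Single vector: the tree is walked from topNode down to a leaf prediction.
      val one: Double = model.predict(Vectors.dense(0.75))

      // Whole RDD of feature vectors in one call.
      val many: RDD[Double] = model.predict(testFeatures)

      println("one = " + one + ", count = " + many.count())
    }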

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala
@@ -22,7 +22,7 @@ package org.apache.spark.mllib.tree.model
* @param split split specifying the feature index, type and threshold
* @param comparison integer specifying <,=,>
*/
- case class Filter(split: Split, comparison: Int) {
+ private[tree] case class Filter(split: Split, comparison: Int) {
// Comparison -1,0,1 signifies <.=,>
override def toString = " split = " + split + "comparison = " + comparison
}

mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -18,6 +18,8 @@
package org.apache.spark.mllib.tree.model

/**
+ * <span class="badge badge-red" style="float: right;">DEVELOPER API</span>
+ *
* Information gain statistics for each split
* @param gain information gain value
* @param impurity current node impurity

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -22,6 +22,8 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.linalg.Vector

/**
+ * <span class="badge" style="float: right; background-color: #44751E;">DEVELOPER API</span>
+ *
* Node in a decision tree
* @param id integer node id
* @param predict predicted value at the node

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
@@ -30,7 +30,7 @@ case class Split(
feature: Int,
threshold: Double,
featureType: FeatureType,
- categories: List[Double]){
+ categories: List[Double]) {

override def toString =
"Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType +
@@ -42,15 +42,15 @@
* @param feature feature index
* @param featureType type of feature -- categorical or continuous
*/
- class DummyLowSplit(feature: Int, featureType: FeatureType)
+ private[tree] class DummyLowSplit(feature: Int, featureType: FeatureType)
extends Split(feature, Double.MinValue, featureType, List())

/**
* Split with maximum threshold for continuous features. Helps with the highest bin creation.
* @param feature feature index
* @param featureType type of feature -- categorical or continuous
*/
- class DummyHighSplit(feature: Int, featureType: FeatureType)
+ private[tree] class DummyHighSplit(feature: Int, featureType: FeatureType)
extends Split(feature, Double.MaxValue, featureType, List())

/**
@@ -59,6 +59,6 @@ class DummyHighSplit(feature: Int, featureType: FeatureType)
* @param feature feature index
* @param featureType type of feature -- categorical or continuous
*/
- class DummyCategoricalSplit(feature: Int, featureType: FeatureType)
+ private[tree] class DummyCategoricalSplit(feature: Int, featureType: FeatureType)
extends Split(feature, Double.MaxValue, featureType, List())
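
For reference, what instances of the public Split case class look like; the field names come from the hunk above, while the concrete values and the reading of `categories` as the values routed into a bin are illustrative assumptions:

    import org.apache.spark.mllib.tree.configuration.FeatureType.{Categorical, Continuous}
    import org.apache.spark.mllib.tree.model.Split

    // A continuous split on feature 0 at threshold 0.5; no categories apply.
    val continuousSplit =
      Split(feature = 0, threshold = 0.5, featureType = Continuous, categories = List())

    // A categorical split on feature 2 accepting category values 1.0 and 3.0;
    // the threshold carries no meaning for categorical features.
    val categoricalSplit =
      Split(feature = 2, threshold = Double.MinValue, featureType = Categorical,
        categories = List(1.0, 3.0))

    println(continuousSplit)  // rendered by the toString defined above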
