code documentation
Signed-off-by: Manish Amde <[email protected]>
manishamde committed Feb 28, 2014
1 parent 9372779 commit 84f85d6
Showing 14 changed files with 157 additions and 21 deletions.
31 changes: 18 additions & 13 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -29,7 +29,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.Algo._

-/*
+/**
A class that implements a decision tree algorithm for classification and regression.
It supports both continuous and categorical features.
@@ -40,7 +40,7 @@ quantile calculation strategy, etc.
*/
class DecisionTree(val strategy : Strategy) extends Serializable with Logging {

-/*
+/**
Method to train a decision tree model over an RDD
@param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
@@ -157,14 +157,14 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {

object DecisionTree extends Serializable with Logging {

-/*
+/**
Returns an Array[Split] of optimal splits for all nodes at a given level
@param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
for DecisionTree
@param parentImpurities Impurities for all parent nodes for the current level
@param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
parameters for constructing the DecisionTree
@param level Level of the tree
@param filters Filter for all nodes at a given level
@param splits possible splits for all features
@@ -200,7 +200,7 @@ object DecisionTree extends Serializable with Logging {
}
}

-/*
+/**
Find whether the sample is valid input for the current node.
In other words, does it pass through all the filters for the current node.
*/
@@ -236,7 +236,9 @@ object DecisionTree extends Serializable with Logging {
true
}

-/*Finds the right bin for the given feature*/
+/**
+Finds the right bin for the given feature
+*/
def findBin(featureIndex: Int, labeledPoint: LabeledPoint, isFeatureContinuous : Boolean) : Int = {

if (isFeatureContinuous){
@@ -266,7 +268,8 @@

}

-/*Finds bins for all nodes (and all features) at a given level
+/**
+Finds bins for all nodes (and all features) at a given level
k features, l nodes (level = log2(l))
Storage label, b11, b12, b13, .., bk, b21, b22, .. ,bl1, bl2, .. ,blk
Denotes invalid sample for tree by noting bin for feature 1 as -1
@@ -343,7 +346,8 @@
}
}

-/*Performs a sequential aggregation over a partition.
+/**
+Performs a sequential aggregation over a partition.
for p bins, k features, l nodes (level = log2(l)) storage is of the form:
b111_left_count,b111_right_count, .... , ..
@@ -370,7 +374,8 @@
}
logDebug("binAggregateLength = " + binAggregateLength)

-/*Combines the aggregates from partitions
+/**
+Combines the aggregates from partitions
@param agg1 Array containing aggregates from one or more partitions
@param agg2 Array containing aggregates from one or more partitions
@@ -507,7 +512,7 @@ object DecisionTree extends Serializable with Logging {
}
}

-/*
+/**
Extracts left and right split aggregates
@param binData Array[Double] of size 2*numFeatures*numSplits
@@ -604,7 +609,7 @@ object DecisionTree extends Serializable with Logging {
gains
}

-/*
+/**
Find the best split for a node given bin aggregate data
@param binData Array[Double] of size 2*numSplits*numFeatures
@@ -669,7 +674,7 @@ object DecisionTree extends Serializable with Logging {
bestSplits
}

-/*
+/**
Returns split and bins for decision tree calculation.
@param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
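The bin-aggregate layout described in the comments above (per node, per feature, per bin, a left count and a right count) implies simple flat-index arithmetic. A minimal sketch of that arithmetic; the helper name and the node-major ordering are assumptions for illustration, not the implementation's:

// Index of the left count for (node, feature, bin) in a flat bin-aggregate
// array holding 2 doubles per bin; the right count sits at the next offset.
// Layout assumption: node-major, then feature, then bin.
def binAggregateIndex(node: Int, feature: Int, bin: Int,
                      numFeatures: Int, numBins: Int): Int =
  2 * ((node * numFeatures + feature) * numBins + bin)

// Example: with 10 features and 100 bins, node 1 / feature 2 / bin 3
// maps to offset 2 * ((1 * 10 + 2) * 100 + 3) = 2406.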
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
@@ -16,6 +16,9 @@
*/
package org.apache.spark.mllib.tree.configuration

+/**
+ * Enum to select the algorithm for the decision tree
+ */
object Algo extends Enumeration {
type Algo = Value
val Classification, Regression = Value
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala
@@ -16,6 +16,9 @@
*/
package org.apache.spark.mllib.tree.configuration

+/**
+ * Enum to describe whether a feature is "continuous" or "categorical"
+ */
object FeatureType extends Enumeration {
type FeatureType = Value
val Continuous, Categorical = Value
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala
@@ -16,6 +16,9 @@
*/
package org.apache.spark.mllib.tree.configuration

+/**
+ * Enum for selecting the quantile calculation strategy
+ */
object QuantileStrategy extends Enumeration {
type QuantileStrategy = Value
val Sort, MinMax, ApproxHist = Value
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -20,6 +20,19 @@ import org.apache.spark.mllib.tree.impurity.Impurity
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._

+/**
+ * Stores all the configuration options for tree construction
+ * @param algo classification or regression
+ * @param impurity criterion used for information gain calculation
+ * @param maxDepth maximum depth of the tree
+ * @param maxBins maximum number of bins used for splitting features
+ * @param quantileCalculationStrategy algorithm for calculating quantiles
+ * @param categoricalFeaturesInfo A map storing information about the categorical variables and
+ *                                the number of discrete values they take. For example, an entry
+ *                                (n -> k) implies feature n is categorical with k categories
+ *                                0, 1, 2, ..., k-1. Note that features are zero-indexed.
+ */
class Strategy (
val algo : Algo,
val impurity : Impurity,
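For orientation, a hedged construction sketch using the documented parameters; the constructor's exact signature and any default values are assumptions based on the parameter list above, not part of this diff:

import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.impurity.Gini

// Hypothetical: binary classification with Gini impurity; feature 0 has
// 2 categories and feature 3 has 4 categories, both zero-indexed.
val strategy = new Strategy(
  algo = Classification,
  impurity = Gini,
  maxDepth = 5,
  maxBins = 100,
  quantileCalculationStrategy = Sort,
  categoricalFeaturesInfo = Map(0 -> 2, 3 -> 4))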
mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -18,10 +18,20 @@ package org.apache.spark.mllib.tree.impurity

import javax.naming.OperationNotSupportedException

+/**
+ * Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during
+ * binary classification.
+ */
object Entropy extends Impurity {

def log2(x: Double) = scala.math.log(x) / scala.math.log(2)

+/**
+ * entropy calculation
+ * @param c0 count of instances with label 0
+ * @param c1 count of instances with label 1
+ * @return entropy value
+ */
def calculate(c0: Double, c1: Double): Double = {
if (c0 == 0 || c1 == 0) {
0
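The guard above reflects the convention 0 * log2(0) = 0; with p = c0 / (c0 + c1), the value is -p * log2(p) - (1 - p) * log2(1 - p). A self-contained sketch of the same computation:

// Binary entropy from label counts; 0 for a pure node, 1 for a 50/50 split.
def log2(x: Double): Double = math.log(x) / math.log(2)

def binaryEntropy(c0: Double, c1: Double): Double =
  if (c0 == 0 || c1 == 0) 0.0
  else {
    val total = c0 + c1
    val p0 = c0 / total
    val p1 = c1 / total
    -p0 * log2(p0) - p1 * log2(p1)
  }

// binaryEntropy(5, 5) == 1.0; binaryEntropy(10, 0) == 0.0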
mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -18,8 +18,18 @@ package org.apache.spark.mllib.tree.impurity

import javax.naming.OperationNotSupportedException

+/**
+ * Class for calculating the [[http://en.wikipedia.org/wiki/Gini_coefficient Gini
+ * coefficient]] during binary classification
+ */
object Gini extends Impurity {

+/**
+ * gini coefficient calculation
+ * @param c0 count of instances with label 0
+ * @param c1 count of instances with label 1
+ * @return gini coefficient value
+ */
def calculate(c0 : Double, c1 : Double): Double = {
if (c0 == 0 || c1 == 0) {
0
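What this computes is the binary Gini impurity 1 - p0^2 - p1^2 over the two class proportions; a self-contained sketch:

// Binary Gini impurity from label counts; 0 for a pure node, 0.5 for 50/50.
def binaryGini(c0: Double, c1: Double): Double =
  if (c0 == 0 || c1 == 0) 0.0
  else {
    val total = c0 + c1
    val p0 = c0 / total
    val p1 = c1 / total
    1.0 - p0 * p0 - p1 * p1
  }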
mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -19,9 +19,19 @@ package org.apache.spark.mllib.tree.impurity
import javax.naming.OperationNotSupportedException
import org.apache.spark.Logging

+/**
+ * Class for calculating variance during regression
+ */
object Variance extends Impurity with Logging {
def calculate(c0: Double, c1: Double): Double = throw new OperationNotSupportedException("Variance.calculate")

+/**
+ * variance calculation
+ * @param count number of instances
+ * @param sum sum of labels
+ * @param sumSquares summation of squares of the labels
+ * @return variance value
+ */
def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
val squaredLoss = sumSquares - (sum*sum)/count
squaredLoss/count
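The expression above is the moment identity variance = E[y^2] - (E[y])^2, computed as (sumSquares - sum^2 / count) / count. A standalone sketch with a worked check:

// Variance from running moments, matching the computation shown above.
def varianceFromMoments(count: Double, sum: Double, sumSquares: Double): Double =
  (sumSquares - (sum * sum) / count) / count

// Labels 1, 2, 3: count = 3, sum = 6, sumSquares = 14,
// so the variance is (14 - 36 / 3) / 3 = 2.0 / 3.0.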
11 changes: 11 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
@@ -18,6 +18,17 @@ package org.apache.spark.mllib.tree.model

import org.apache.spark.mllib.tree.configuration.FeatureType._

+/**
+ * Used for "binning" the feature values for faster best split calculation. For a continuous
+ * feature, a bin is determined by a low and a high "split". For a categorical feature,
+ * a bin is determined by a single label value (category).
+ * @param lowSplit signifying the lower threshold for the continuous feature to be
+ *                 accepted in the bin
+ * @param highSplit signifying the upper threshold for the continuous feature to be
+ *                  accepted in the bin
+ * @param featureType type of feature -- categorical or continuous
+ * @param category categorical label value accepted in the bin
+ */
case class Bin(lowSplit : Split, highSplit : Split, featureType : FeatureType, category : Double) {

}
mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -16,12 +16,23 @@
*/
package org.apache.spark.mllib.tree.model

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.rdd.RDD

+/**
+ * Model to store the decision tree parameters
+ * @param topNode root node
+ * @param algo algorithm type -- classification or regression
+ */
class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializable {

-def predict(features : Array[Double]) = {
+/**
+ * Predict values for a single data point using the model trained.
+ *
+ * @param features array representing a single data point
+ * @return Double prediction from the trained model
+ */
+def predict(features : Array[Double]) : Double = {
algo match {
case Classification => {
if (topNode.predictIfLeaf(features) < 0.5) 0.0 else 1.0
@@ -32,4 +43,15 @@ class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializable
}
}

+/**
+ * Predict values for the given data set using the model trained.
+ *
+ * @param features RDD representing data points to be predicted
+ * @return RDD[Double] where each entry contains the corresponding prediction
+ */
+def predict(features: RDD[Array[Double]]): RDD[Double] = {
+  features.map(x => predict(x))
+}


}
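An end-to-end usage sketch of the two predict methods added here. The train entry-point name, the hypothetical SparkContext `sc`, the CSV layout, and the `strategy` value (from the Strategy sketch earlier) are all assumptions, not part of this diff:

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// Hypothetical: label in the first CSV column, features in the rest.
val trainingData: RDD[LabeledPoint] = sc.textFile("data.csv").map { line =>
  val parts = line.split(',').map(_.toDouble)
  LabeledPoint(parts.head, parts.tail)
}
val model: DecisionTreeModel = new DecisionTree(strategy).train(trainingData)

// Single-point and bulk prediction via the methods documented above.
val one: Double = model.predict(Array(0.0, 1.5, 2.0))
val all: RDD[Double] = model.predict(trainingData.map(_.features))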
mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala
@@ -16,6 +16,11 @@
*/
package org.apache.spark.mllib.tree.model

+/**
+ * Filter specifying a split and type of comparison to be applied on features
+ * @param split split specifying the feature index, type and threshold
+ * @param comparison integer specifying <, =, >
+ */
case class Filter(split : Split, comparison : Int) {
// Comparison -1,0,1 signifies <,=,>
override def toString = " split = " + split + "comparison = " + comparison
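A sketch of how such a filter might be checked against a data point, assuming the documented encoding (-1, 0, 1 for <, =, >) compares the feature value with the split threshold; categorical handling is omitted:

// Illustrative only; not the implementation's filter check.
def passesFilter(filter: Filter, features: Array[Double]): Boolean = {
  val value = features(filter.split.feature)
  filter.comparison match {
    case -1 => value < filter.split.threshold
    case 0  => value == filter.split.threshold
    case _  => value > filter.split.threshold
  }
}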
mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -16,6 +16,14 @@
*/
package org.apache.spark.mllib.tree.model

+/**
+ * Information gain statistics for each split
+ * @param gain information gain value
+ * @param impurity current node impurity
+ * @param leftImpurity left node impurity
+ * @param rightImpurity right node impurity
+ * @param predict predicted value
+ */
class InformationGainStats(
val gain : Double,
val impurity: Double,
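These fields relate through the standard impurity-reduction formula: gain is the parent impurity minus the size-weighted child impurities. A sketch under that assumption (the left/right sample fractions are not stored in this class):

// gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity,
// where leftWeight + rightWeight == 1 are the child sample fractions.
def informationGain(impurity: Double, leftImpurity: Double, rightImpurity: Double,
                    leftWeight: Double): Double =
  impurity - leftWeight * leftImpurity - (1.0 - leftWeight) * rightImpurity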
10 changes: 10 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -20,6 +20,16 @@ import org.apache.spark.Logging
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.FeatureType._

+/**
+ * Node in a decision tree
+ * @param id integer node id
+ * @param predict predicted value at the node
+ * @param isLeaf whether the node is a leaf
+ * @param split split to calculate left and right nodes
+ * @param leftNode left child
+ * @param rightNode right child
+ * @param stats information gain stats
+ */
class Node ( val id : Int,
val predict : Double,
val isLeaf : Boolean,
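A sketch of the top-down prediction that predictIfLeaf (used by DecisionTreeModel above) presumably performs for continuous splits; the Option-valued fields and the left-means-below-threshold convention are assumptions:

// Illustrative recursive descent; not the committed implementation.
def predictIfLeaf(node: Node, features: Array[Double]): Double =
  if (node.isLeaf) node.predict
  else {
    val s = node.split.get  // assumed Option[Split]
    if (features(s.feature) <= s.threshold)
      predictIfLeaf(node.leftNode.get, features)  // assumed Option[Node]
    else
      predictIfLeaf(node.rightNode.get, features)
  }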
35 changes: 29 additions & 6 deletions mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
@@ -18,6 +18,13 @@ package org.apache.spark.mllib.tree.model

import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType

+/**
+ * Split applied to a feature
+ * @param feature feature index
+ * @param threshold threshold for continuous feature
+ * @param featureType type of feature -- categorical or continuous
+ * @param categories accepted values for categorical variables
+ */
case class Split(
feature: Int,
threshold : Double,
@@ -29,12 +36,28 @@ case class Split(
", categories = " + categories
}

-class DummyLowSplit(feature: Int, kind : FeatureType)
-  extends Split(feature, Double.MinValue, kind, List())
+/**
+ * Split with minimum threshold for continuous features. Helps with the smallest bin creation.
+ * @param feature feature index
+ * @param featureType type of feature -- categorical or continuous
+ */
+class DummyLowSplit(feature: Int, featureType : FeatureType)
+  extends Split(feature, Double.MinValue, featureType, List())

-class DummyHighSplit(feature: Int, kind : FeatureType)
-  extends Split(feature, Double.MaxValue, kind, List())
+/**
+ * Split with maximum threshold for continuous features. Helps with the highest bin creation.
+ * @param feature feature index
+ * @param featureType type of feature -- categorical or continuous
+ */
+class DummyHighSplit(feature: Int, featureType : FeatureType)
+  extends Split(feature, Double.MaxValue, featureType, List())

-class DummyCategoricalSplit(feature: Int, kind : FeatureType)
-  extends Split(feature, Double.MaxValue, kind, List())
+/**
+ * Split with no acceptable feature values for categorical features. Helps with the first bin
+ * creation.
+ * @param feature feature index
+ * @param featureType type of feature -- categorical or continuous
+ */
+class DummyCategoricalSplit(feature: Int, featureType : FeatureType)
+  extends Split(feature, Double.MaxValue, featureType, List())
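The dummy splits exist to close the outermost bins. A sketch of assembling continuous-feature bins from sorted candidate thresholds, consistent with the Bin documentation earlier in this diff; the helper itself is illustrative, not the committed code:

import org.apache.spark.mllib.tree.configuration.FeatureType._

// Bracket the thresholds with the dummy splits, then pair up adjacent
// splits into bins; the category field is unused for continuous features.
def continuousBins(featureIndex: Int, thresholds: Array[Double]): Array[Bin] = {
  val inner = thresholds.map(t => Split(featureIndex, t, Continuous, List()))
  val low = new DummyLowSplit(featureIndex, Continuous)
  val high = new DummyHighSplit(featureIndex, Continuous)
  val splits: Array[Split] = (low +: inner) :+ high
  splits.sliding(2).map { case Array(lo, hi) =>
    Bin(lo, hi, Continuous, Double.NaN)
  }.toArray
}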
