Choose splits for continuous features in DecisionTree more adaptively
chouqin committed Oct 9, 2014
1 parent 14f222f commit af7cb79
Showing 2 changed files with 75 additions and 7 deletions.
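Background for the first diff hunk below: the removed lines choose split thresholds at fixed positions in the sorted sample (one every stride samples), so a skewed continuous feature with many repeated values can end up with several identical thresholds. A minimal standalone sketch of that stride-based behaviour; the object name, sample values and split counts are made up for illustration, and only the stride and halfway-threshold formulas come from the removed code:

// Sketch of the pre-existing fixed-stride selection (see the removed lines below).
// The sample array and split counts are illustrative only.
object StrideSplitsSketch {
  def main(args: Array[String]): Unit = {
    val featureSamples = Array(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 3.0).sorted
    val numSplits = 3
    val numBins = numSplits + 1
    val stride: Double = featureSamples.length.toDouble / numBins // 8 / 4 = 2.0
    val thresholds = (0 until numSplits).map { splitIndex =>
      val sampleIndex = splitIndex * stride.toInt
      // Threshold halfway between two adjacent samples, as in the removed code.
      (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0
    }
    // All three thresholds collapse onto the repeated value:
    println(thresholds.mkString(", ")) // prints 1.0, 1.0, 1.0
  }
}
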
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom
import org.apache.spark.SparkContext._
+import scala.collection.mutable.ArrayBuffer


/**
@@ -912,16 +913,19 @@ object DecisionTree extends Serializable with Logging {
        val numSplits = metadata.numSplits(featureIndex)
        val numBins = metadata.numBins(featureIndex)
        if (metadata.isContinuous(featureIndex)) {
-         val numSamples = sampledInput.length
+
+         val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
+         val featureSplits = findSplits(featureSamples, metadata.numSplits(featureIndex))
+         metadata.setNumBinForFeature(featureIndex, metadata.numSplits(featureIndex))
+         val numSplits = metadata.numSplits(featureIndex)
+         val numBins = metadata.numBins(featureIndex)
+         logDebug("numSplits= " + numSplits)
+
          splits(featureIndex) = new Array[Split](numSplits)
          bins(featureIndex) = new Array[Bin](numBins)
-         val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
-         val stride: Double = numSamples.toDouble / metadata.numBins(featureIndex)
-         logDebug("stride = " + stride)
+
          for (splitIndex <- 0 until numSplits) {
-           val sampleIndex = splitIndex * stride.toInt
-           // Set threshold halfway in between 2 samples.
-           val threshold = (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0
+           val threshold = featureSplits(splitIndex)
            splits(featureIndex)(splitIndex) =
              new Split(featureIndex, threshold, Continuous, List())
          }
@@ -1011,4 +1015,58 @@ object DecisionTree extends Serializable with Logging {
    categories
  }

+  /**
+   * Find splits for a continuous feature.
+   * @param featureSamples feature values of each sample, in ascending order
+   * @param numSplits the maximum number of splits to find
+   * @return array of split thresholds for the feature
+   */
+  private def findSplits(featureSamples: Array[Double], numSplits: Int): Array[Double] = {
+    /*
+     * Get count for each distinct value
+     */
+    def getValueCount(arr: Array[Double]): Array[(Double, Int)] = {
+      val valueCount = new ArrayBuffer[(Double, Int)]
+      var index = 1
+      var currentValue = arr(0)
+      var currentCount = 1
+      while (index < arr.length) {
+        if (currentValue != arr(index)) {
+          valueCount.append((currentValue, currentCount))
+          currentCount = 1
+          currentValue = arr(index)
+        } else {
+          currentCount += 1
+        }
+        index += 1
+      }
+      valueCount.append((currentValue, currentCount))
+
+      valueCount.toArray
+    }
+
+    val valueCount = getValueCount(featureSamples)
+    if (valueCount.length <= numSplits) {
+      return valueCount.map(_._1)
+    }
+
+    val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
+    logDebug("stride = " + stride)
+
+    val splits = new ArrayBuffer[Double]
+    var index = 1
+    var currentCount = valueCount(0)._2
+    var expectedCount = stride
+    while (index < valueCount.length) {
+      if (math.abs(currentCount - expectedCount) <
+        math.abs(currentCount + valueCount(index)._2 - expectedCount)) {
+        splits.append(valueCount(index - 1)._1)
+        expectedCount += stride
+      }
+      currentCount += valueCount(index)._2
+      index += 1
+    }
+
+    splits.toArray
+  }
}
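
The findSplits method added above walks the (value, count) pairs in ascending order and emits the previous distinct value as a threshold whenever the accumulated sample count is as close as possible to the next multiple of the stride, so thresholds follow where the sample mass actually lies and at most numSplits of them are returned. A self-contained sketch of that occurrence-weighted selection; the object name, the groupBy-based counting and the sample values are illustrative, and only the stride and distance test mirror the method above:

import scala.collection.mutable.ArrayBuffer

// Standalone sketch of the occurrence-weighted selection in findSplits above.
// Sample values and numSplits are made up; results shown are for this toy input.
object AdaptiveSplitsSketch {
  def adaptiveSplits(featureSamples: Array[Double], numSplits: Int): Array[Double] = {
    // (distinct value, occurrence count) pairs in ascending value order.
    val valueCount: Array[(Double, Int)] = featureSamples
      .groupBy(identity)
      .map { case (value, group) => (value, group.length) }
      .toArray
      .sortBy(_._1)
    // Few distinct values: every value becomes a candidate threshold.
    if (valueCount.length <= numSplits) return valueCount.map(_._1)

    // Expected number of samples between consecutive thresholds.
    val stride = featureSamples.length.toDouble / (numSplits + 1)
    val splits = new ArrayBuffer[Double]
    var currentCount = valueCount(0)._2
    var expectedCount = stride
    var index = 1
    while (index < valueCount.length) {
      // Cut before the next value only if stopping here tracks the expected
      // cumulative count better than absorbing that value as well.
      if (math.abs(currentCount - expectedCount) <
        math.abs(currentCount + valueCount(index)._2 - expectedCount)) {
        splits.append(valueCount(index - 1)._1)
        expectedCount += stride
      }
      currentCount += valueCount(index)._2
      index += 1
    }
    splits.toArray
  }

  def main(args: Array[String]): Unit = {
    val samples = Array(1.0, 1.0, 1.0, 2.0, 2.0, 3.0, 4.0, 5.0)
    // Thresholds land on the heavy values instead of collapsing onto one of them.
    println(adaptiveSplits(samples, numSplits = 3).mkString(", ")) // prints 1.0, 2.0, 3.0
  }
}

Because a threshold is emitted only when it brings the cumulative count closer to the target, the method can return fewer than numSplits values (for example when there are few distinct values), which is presumably why the commit also adds the setNumBinForFeature hook to DecisionTreeMetadata in the second file.
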
@@ -75,6 +75,16 @@ private[tree] class DecisionTreeMetadata(
    numBins(featureIndex) - 1
  }

+
+  /**
+   * Set the number of bins for a continuous feature.
+   */
+  def setNumBinForFeature(featureIndex: Int, numBin: Int) {
+    require(isContinuous(featureIndex),
+      s"Can only set the number of bins for a continuous feature.")
+    numBins(featureIndex) = numBin
+  }
+
  /**
   * Indicates if feature subsampling is being used.
   */
